mamba : refactor recurrent conv, resulting in 20% perf increase

It's still slower than I'd like, but I did not really optimize `ggml_exp` yet.

I also refactored `ggml_exp` to work with tensors with more than 2 dimensions.
This commit is contained in:
Francis Couture-Harpin 2024-01-29 10:21:19 -05:00
parent 74eea856bf
commit 9e77061a3b
2 changed files with 23 additions and 22 deletions

17
ggml.c
View file

@ -8678,16 +8678,19 @@ static void ggml_compute_forward_exp_f32(
return; return;
} }
const int n = ggml_nrows(src0);
const int nc = src0->ne[0];
GGML_ASSERT( dst->nb[0] == sizeof(float)); GGML_ASSERT( dst->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) { GGML_TENSOR_UNARY_OP_LOCALS
ggml_vec_exp_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])), for (int64_t i3 = 0; i3 < ne03; i3++) {
(float *) ((char *) src0->data + i*(src0->nb[1]))); for (int64_t i2 = 0; i2 < ne02; i2++) {
for (int64_t i1 = 0; i1 < ne01; i1++) {
float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
ggml_vec_exp_f32(ne00, dst_row, src_row);
}
}
} }
} }

View file

@ -7929,15 +7929,14 @@ struct llm_build_context {
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
// NOTE: not sure what's the difference between the sequence length and the batch size in the paper.
// {n_embd, batch} // {n_embd, batch}
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
cb(inpL, "inp_embd", -1); cb(inpL, "inp_embd", -1);
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
// (ab)using the kv cache to store the state // (ab)using the kv cache to store the state
// NOTE: the conv_state is transposed to ease shifting it. ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv, d_inner);
// if you figured out a way to shift it without transposing it like this, go ahead and fix this.
ggml_tensor * conv_state = kv_self.k_l[il]; // {d_inner, d_conv}
ggml_tensor * ssm_state = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner); ggml_tensor * ssm_state = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);
// norm // norm
@ -7946,33 +7945,32 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il); LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il); cb(cur, "attn_norm", il);
// {n_embd, batch} * {n_embd, 2*d_inner} = {batch, 2*d_inner} // {n_embd, 2*d_inner} * {n_embd, batch} = {2*d_inner, batch}
struct ggml_tensor * xz = ggml_mul_mat(ctx0, cur, model.layers[il].ssm_in); struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
// split the above in two // split the above in two
// assuming it's contiguous // assuming it's contiguous
// FIXME: handle batches of more than 1 token // {d_inner, batch}
struct ggml_tensor * x = ggml_view_1d(ctx0, xz, d_inner, 0); struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, ggml_element_size(xz)*d_inner); struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
cur = x; cur = x;
// conv // conv
{ {
// shift conv state left // shift conv state left
conv_state = ggml_set_1d(ctx0, conv_state, ggml_view_1d(ctx0, conv_state, (d_conv - 1)*d_inner, ggml_element_size(conv_state)*d_inner), 0); conv_state = ggml_set_2d(ctx0, conv_state, ggml_view_2d(ctx0, conv_state, (d_conv - 1), d_inner, conv_state->nb[1], ggml_element_size(conv_state)*1), conv_state->nb[1], 0);
// update last column // update last column
conv_state = ggml_set_1d(ctx0, conv_state, x, ggml_element_size(conv_state)*(d_conv - 1)*d_inner); // x here is {d_inner, 1} (a row), but should be {1, d_inner} (a column)
conv_state = ggml_set_2d(ctx0, conv_state, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_state->nb[1], ggml_element_size(conv_state)*(d_conv - 1));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state, ggml_view_tensor(ctx0, kv_self.k_l[il]))); ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state, ggml_view_tensor(ctx0, kv_self.k_l[il])));
// rearrange and sum // rearrange and sum
conv_state = ggml_reshape_2d(ctx0, conv_state, d_inner, d_conv); // no need to rearrange the conv_state, since it's already in the right shape
// TODO: find a way to directly shift a 2d conv_state, avoiding the need to transpose here. // => {1, d_inner}
conv_state = ggml_cont(ctx0, ggml_transpose(ctx0, conv_state));
// --> {1, d_inner}
x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_state, model.layers[il].ssm_conv1d)); x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_state, model.layers[il].ssm_conv1d));
// => {d_inner, 1}
x = ggml_transpose(ctx0, x); x = ggml_transpose(ctx0, x);
// bias // bias