From 9e77061a3b4d719ecabe68a930aa102f478a34fd Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 29 Jan 2024 10:21:19 -0500 Subject: [PATCH] mamba : refactor recurrent conv, resulting in 20% perf increase It's still slower than I'd like, but I did not really optimize `ggml_exp` yet. I also refactored `ggml_exp` to work with tensors with more than 2 dimensions. --- ggml.c | 17 ++++++++++------- llama.cpp | 28 +++++++++++++--------------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/ggml.c b/ggml.c index dcc70d010..a5b337d96 100644 --- a/ggml.c +++ b/ggml.c @@ -8678,16 +8678,19 @@ static void ggml_compute_forward_exp_f32( return; } - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - GGML_ASSERT( dst->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { - ggml_vec_exp_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); + GGML_TENSOR_UNARY_OP_LOCALS + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + ggml_vec_exp_f32(ne00, dst_row, src_row); + } + } } } diff --git a/llama.cpp b/llama.cpp index fceee6317..048bd8e50 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7929,15 +7929,14 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; + // NOTE: not sure what's the difference between the sequence length and the batch size in the paper. // {n_embd, batch} inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); cb(inpL, "inp_embd", -1); for (int il = 0; il < n_layer; ++il) { // (ab)using the kv cache to store the state - // NOTE: the conv_state is transposed to ease shifting it. - // if you figured out a way to shift it without transposing it like this, go ahead and fix this. - ggml_tensor * conv_state = kv_self.k_l[il]; // {d_inner, d_conv} + ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv, d_inner); ggml_tensor * ssm_state = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner); // norm @@ -7946,33 +7945,32 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - // {n_embd, batch} * {n_embd, 2*d_inner} = {batch, 2*d_inner} - struct ggml_tensor * xz = ggml_mul_mat(ctx0, cur, model.layers[il].ssm_in); + // {n_embd, 2*d_inner} * {n_embd, batch} = {2*d_inner, batch} + struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur); // split the above in two // assuming it's contiguous - // FIXME: handle batches of more than 1 token - struct ggml_tensor * x = ggml_view_1d(ctx0, xz, d_inner, 0); - struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, ggml_element_size(xz)*d_inner); + // {d_inner, batch} + struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0); + struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner); cur = x; // conv { // shift conv state left - conv_state = ggml_set_1d(ctx0, conv_state, ggml_view_1d(ctx0, conv_state, (d_conv - 1)*d_inner, ggml_element_size(conv_state)*d_inner), 0); + conv_state = ggml_set_2d(ctx0, conv_state, ggml_view_2d(ctx0, conv_state, (d_conv - 1), d_inner, conv_state->nb[1], ggml_element_size(conv_state)*1), conv_state->nb[1], 0); // update last column - conv_state = ggml_set_1d(ctx0, conv_state, x, ggml_element_size(conv_state)*(d_conv - 1)*d_inner); + // x here is {d_inner, 1} (a row), but should be {1, d_inner} (a column) + conv_state = ggml_set_2d(ctx0, conv_state, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_state->nb[1], ggml_element_size(conv_state)*(d_conv - 1)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state, ggml_view_tensor(ctx0, kv_self.k_l[il]))); // rearrange and sum - conv_state = ggml_reshape_2d(ctx0, conv_state, d_inner, d_conv); - // TODO: find a way to directly shift a 2d conv_state, avoiding the need to transpose here. - conv_state = ggml_cont(ctx0, ggml_transpose(ctx0, conv_state)); - - // --> {1, d_inner} + // no need to rearrange the conv_state, since it's already in the right shape + // => {1, d_inner} x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_state, model.layers[il].ssm_conv1d)); + // => {d_inner, 1} x = ggml_transpose(ctx0, x); // bias