From a3f4a1c7dc9fc10082d5290b49505bc3d3db239c Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Sat, 3 Feb 2024 17:49:36 -0500
Subject: [PATCH] mamba : fuse more steps of the SSM scan in the
 ggml_ssm_scan operator

This increases performance on CPU by around 30% for prompt processing,
and by around 20% for text generation.

However, it also makes the ggml_exp and ggml_soft_plus operators unused.
Whether or not they should be kept will be decided later.
---
 ggml.c    | 115 ++++++++++++++++++++++++++++++++++++------------------
 ggml.h    |   6 ++-
 llama.cpp |  49 +++--------------------
 3 files changed, 88 insertions(+), 82 deletions(-)

diff --git a/ggml.c b/ggml.c
index b132bec68..90dcddbb7 100644
--- a/ggml.c
+++ b/ggml.c
@@ -6156,31 +6156,45 @@ struct ggml_tensor * ggml_flash_attn_back(
 struct ggml_tensor * ggml_ssm_scan(
         struct ggml_context * ctx,
         struct ggml_tensor  * s,
-        struct ggml_tensor  * dA,
-        struct ggml_tensor  * dB_x) {
-    GGML_ASSERT(ggml_are_same_shape(dA, dB_x));
-
-    GGML_ASSERT(   s->nb[0] == ggml_type_size(   s->type));
-    GGML_ASSERT(  dA->nb[0] == ggml_type_size(  dA->type));
-    GGML_ASSERT(dB_x->nb[0] == ggml_type_size(dB_x->type));
-
-    GGML_ASSERT(s->ne[0] == dA->ne[0]);
-    GGML_ASSERT(s->ne[1] == dA->ne[1]);
+        struct ggml_tensor  * x,
+        struct ggml_tensor  * dt,
+        struct ggml_tensor  * A,
+        struct ggml_tensor  * B) {
+    GGML_ASSERT(ggml_is_contiguous(s));
+    GGML_ASSERT(ggml_is_contiguous(x));
+    GGML_ASSERT(ggml_is_contiguous(dt));
+    GGML_ASSERT(ggml_is_contiguous(A));
+    GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
+    GGML_ASSERT(ggml_are_same_shape(x, dt));
     GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1); // the ssm_state should be 2D
 
+    {
+        const int64_t d_state = s->ne[0];
+        const int64_t d_inner = s->ne[1];
+        const int64_t n_tok   = x->ne[1];
+
+        GGML_ASSERT(x->ne[0] == d_inner);
+        GGML_ASSERT(A->ne[0] == d_state);
+        GGML_ASSERT(A->ne[1] == d_inner);
+        GGML_ASSERT(B->ne[0] == d_state);
+        GGML_ASSERT(B->ne[1] == n_tok);
+    }
+
    bool is_node = false;
 
-    if (s->grad || dA->grad || dB_x->grad) {
+    if (s->grad || x->grad || dt->grad || A->grad || B->grad) {
         is_node = true;
     }
 
-    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, dA->ne[0], dA->ne[1], dA->ne[2]);
+    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, s->ne[0], s->ne[1], x->ne[1]);
 
     result->op   = GGML_OP_SSM_SCAN;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = s;
-    result->src[1] = dA;
-    result->src[2] = dB_x;
+    result->src[1] = x;
+    result->src[2] = dt;
+    result->src[3] = A;
+    result->src[4] = B;
 
     return result;
 }
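For reference, the per-token recurrence that the operator now fuses (softplus of dt, discretization of A and B, then the state update) can be written on its own. Below is a minimal standalone C sketch of that step; the function name and layout are illustrative only, mirroring the inner loops of the compute function later in this patch:

    // Sketch of the recurrence fused into ggml_ssm_scan for a single token:
    //   dt_soft_plus = softplus(dt)
    //   dA           = exp(dt_soft_plus * A)
    //   dB*x         = B * (x * dt_soft_plus)
    //   s            = s*dA + dB*x
    // Names and layout are illustrative, matching the asserts above.
    #include <math.h>

    static void ssm_scan_one_token(
            float       * s,   // {d_state, d_inner} state, updated in place
            const float * x,   // {d_inner}          input row for this token
            const float * dt,  // {d_inner}          raw dt, before softplus
            const float * A,   // {d_state, d_inner}
            const float * B,   // {d_state}          B column for this token
            int d_state, int d_inner) {
        for (int i1 = 0; i1 < d_inner; ++i1) {
            const float dt_soft_plus = log1pf(expf(dt[i1])); // softplus, fused
            const float x_dt         = x[i1] * dt_soft_plus;
            for (int i0 = 0; i0 < d_state; ++i0) {
                const int i = i0 + i1*d_state;
                // ssm_state * dA + dB * x
                s[i] = s[i]*expf(dt_soft_plus * A[i]) + B[i0]*x_dt;
            }
        }
    }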
@@ -14795,9 +14809,11 @@ static void ggml_compute_forward_flash_attn_back(
 
 static void ggml_compute_forward_ssm_scan_f32(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        const struct ggml_tensor * src2,
+        const struct ggml_tensor * src0, // s
+        const struct ggml_tensor * src1, // x
+        const struct ggml_tensor * src2, // dt
+        const struct ggml_tensor * src3, // A
+        const struct ggml_tensor * src4, // B
               struct ggml_tensor * dst) {
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14806,18 +14822,19 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t nc  = src1->ne[0];
-    const int64_t n_t = src1->ne[2]; // number of tokens in the batch
+    const int64_t nc  = src0->ne[0];
+    const int64_t n_t = src1->ne[1]; // number of tokens in the batch
     const int64_t nr0 = ggml_nrows(src0);
 
-    GGML_ASSERT(nc*n_t*nr0 == ggml_nelements(src1));
+    GGML_ASSERT(nc*n_t*nr0 == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
+    GGML_ASSERT(src3->nb[0] == sizeof(float));
+    GGML_ASSERT(src4->nb[0] == sizeof(float));
     // allow merging multiple rows in the same vec operation
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    GGML_ASSERT(src1->nb[1] == src1->ne[0]*sizeof(float));
-    GGML_ASSERT(src2->nb[1] == src2->ne[0]*sizeof(float));
+    GGML_ASSERT(src3->nb[1] == src3->ne[0]*sizeof(float));
 
     // rows per thread
     const int dr = (nr0 + nth - 1)/nth;
@@ -14829,22 +14846,44 @@ static void ggml_compute_forward_ssm_scan_f32(
 
     // first batch
     {
-        float * dest = (float *) ((char *)  dst->data + ir0*( dst->nb[1]));
-        float * s    = (float *) ((char *) src0->data + ir0*(src0->nb[1]));
-        float * dA   = (float *) ((char *) src1->data + ir0*(src1->nb[1]));
-        float * dB_x = (float *) ((char *) src2->data + ir0*(src2->nb[1]));
-        ggml_vec_mul_f32(nc*ir, dest, s, dA);
-        ggml_vec_add_f32(nc*ir, dest, dest, dB_x);
+        float * dest = (float *) ((char *)  dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
+        float * s    = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
+        float * x    = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
+        float * dt   = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
+        float * A    = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+        float * B    = (float *) ((char *) src4->data);                     // {d_state, n_tok}
+        // d_inner
+        for (int i1 = 0; i1 < ir; ++i1) {
+            float dt_soft_plus = log1pf(expf(dt[i1]));
+            float x_dt = x[i1] * dt_soft_plus;
+            // d_state
+            for (int i0 = 0; i0 < nc; ++i0) {
+                int i = i0 + i1*nc;
+                // ssm_state * dA + dB * x
+                dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+            }
+        }
     }
 
     // compute state for rest of tokens, previous state comes from dest
-    for (int i2 = 1; i2 < n_t; i2++) {
-        float * dest = (float *) ((char *)  dst->data + ir0*( dst->nb[1]) +    i2 *( dst->nb[2]));
-        float * s    = (float *) ((char *)  dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2]));
-        float * dA   = (float *) ((char *) src1->data + ir0*(src1->nb[1]) +    i2 *(src1->nb[2]));
-        float * dB_x = (float *) ((char *) src2->data + ir0*(src2->nb[1]) +    i2 *(src2->nb[2]));
-        ggml_vec_mul_f32(nc*ir, dest, s, dA);
-        ggml_vec_add_f32(nc*ir, dest, dest, dB_x);
+    for (int i2 = 1; i2 < n_t; ++i2) {
+        float * dest = (float *) ((char *)  dst->data + ir0*( dst->nb[1]) +    i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
+        float * s    = (float *) ((char *)  dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
+        float * x    = (float *) ((char *) src1->data + ir0*(src1->nb[0]) +    i2 *(src1->nb[1])); // {d_inner, n_tok}
+        float * dt   = (float *) ((char *) src2->data + ir0*(src2->nb[0]) +    i2 *(src2->nb[1])); // {d_inner, n_tok}
+        float * A    = (float *) ((char *) src3->data + ir0*(src3->nb[1]));                        // {d_state, d_inner}
+        float * B    = (float *) ((char *) src4->data +    i2 *(src4->nb[1]));                     // {d_state, n_tok}
+        // d_inner
+        for (int i1 = 0; i1 < ir; ++i1) {
+            float dt_soft_plus = log1pf(expf(dt[i1]));
+            float x_dt = x[i1] * dt_soft_plus;
+            // d_state
+            for (int i0 = 0; i0 < nc; ++i0) {
+                int i = i0 + i1*nc;
+                // ssm_state * dA + dB * x
+                dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+            }
+        }
     }
 }
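The hunk above is what makes the scan sequential over tokens: the state written for token i2-1 into dst is read back as the previous state for token i2, so only one graph node is needed per layer regardless of n_tok. A minimal sketch of that chaining, reusing the illustrative ssm_scan_one_token() helper from the earlier sketch (names and layout are assumptions, not part of the patch):

    #include <string.h>

    // Sketch: run the fused scan over a whole batch of tokens. dst holds one
    // {d_state, d_inner} state slab per token, so the last slab along the
    // token dimension is the state to carry forward to the next batch.
    static void ssm_scan_batch(
            float       * dst, // {d_state, d_inner, n_tok}
            const float * s0,  // {d_state, d_inner} incoming state
            const float * x,   // {d_inner, n_tok}
            const float * dt,  // {d_inner, n_tok}
            const float * A,   // {d_state, d_inner}
            const float * B,   // {d_state, n_tok}
            int d_state, int d_inner, int n_tok) {
        const size_t slab = (size_t) d_state * d_inner;
        for (int t = 0; t < n_tok; ++t) {
            // token 0 starts from the caller's state; later tokens start
            // from the state just written for the previous token
            const float * prev = (t == 0) ? s0 : dst + (t - 1)*slab;
            memcpy(dst + t*slab, prev, slab*sizeof(float));
            ssm_scan_one_token(dst + t*slab,
                               x  + (size_t) t*d_inner,
                               dt + (size_t) t*d_inner,
                               A,
                               B  + (size_t) t*d_state,
                               d_state, d_inner);
        }
    }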
@@ -14853,11 +14892,13 @@ static void ggml_compute_forward_ssm_scan(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         const struct ggml_tensor * src2,
+        const struct ggml_tensor * src3,
+        const struct ggml_tensor * src4,
               struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, dst);
+                ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, dst);
             } break;
         default:
             {
@@ -15927,7 +15968,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SSM_SCAN:
             {
-                ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+                ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor);
             } break;
         case GGML_OP_WIN_PART:
             {
diff --git a/ggml.h b/ggml.h
index 0a40f8762..3a4c9201a 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1724,8 +1724,10 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_ssm_scan(
             struct ggml_context * ctx,
             struct ggml_tensor  * s,
-            struct ggml_tensor  * dA,
-            struct ggml_tensor  * dB_x);
+            struct ggml_tensor  * x,
+            struct ggml_tensor  * dt,
+            struct ggml_tensor  * A,
+            struct ggml_tensor  * B);
 
     // partition into non-overlapping windows with padding if needed
     // example:
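A usage note on the new header: callers now pass the raw x and dt (before softplus) together with A and B, and the operator handles the softplus and discretization internally. A hedged sketch of a call site, with placeholder tensor variables (the real caller is the llama.cpp hunk below):

    // Sketch of a call with the new five-tensor signature; the tensor
    // variables here are placeholders. dt is passed *before* softplus,
    // which is what leaves ggml_soft_plus (and ggml_exp) unused.
    struct ggml_tensor * states = ggml_ssm_scan(ctx, ssm_state, x, dt, A, B);
    // states is {d_state, d_inner, n_tok}; the slab at ne[2] == n_tok - 1
    // is the state to carry over to the next batch.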
diff --git a/llama.cpp b/llama.cpp
index 9dba8eeb2..466f8bc0c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -7999,55 +7999,18 @@ struct llm_build_context {
             struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
             // split
             struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tok, x_db->nb[1], 0);
-            struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
-            struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
+            struct ggml_tensor * B  = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+            struct ggml_tensor * C  = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
 
             // {dt_rank, d_inner} * {dt_rank, n_tok} => {d_inner, n_tok}
             dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
             dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
-            dt = ggml_soft_plus(ctx0, dt);
 
-            struct ggml_tensor * dA;
-            struct ggml_tensor * dB;
-            if (n_tok == 1) {
-                // => {d_state, d_inner}
-                dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, ggml_transpose(ctx0, dt)));
-
-                // {d_state} * {d_inner} => {d_state, d_inner}
-                dB = ggml_out_prod(ctx0, B, dt);
-            } else {
-                // {d_state, d_inner} * {d_inner, n_tok} => {d_state, d_inner, n_tok} * {1, d_inner, n_tok}
-                // => {d_state, d_inner, n_tok}
-                // Trying to do the equivalent of
-                // dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate)
-                struct ggml_tensor * A = model.layers[il].ssm_a;
-                dA = ggml_exp(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, A, ggml_new_tensor_3d(ctx0, A->type, d_state, d_inner, n_tok)),
-                        // {d_inner, n_tok} => {1, d_inner, n_tok}
-                        ggml_permute(ctx0, dt, 1, 2, 0, 3))
-                );
-
-                // {d_state, 1, n_tok} * {d_inner, 1, n_tok} => {d_state, d_inner, n_tok}
-                dB = ggml_out_prod(ctx0,
-                    // {d_state, n_tok} => {d_state, 1, n_tok}
-                    ggml_permute(ctx0, B, 0, 2, 1, 3),
-                    // {d_state, n_tok} => {d_state, 1, n_tok}
-                    ggml_permute(ctx0, dt, 0, 2, 1, 3));
-            }
-
-            // {d_state, d_inner, n_tok} * {1, d_inner, n_tok} => {d_state, d_inner, n_tok}
-            cur = ggml_mul(ctx0, dB, ggml_permute(ctx0, x, 1, 2, 0, 3));
-
-            // The selective scan seems inherently sequential...
-            // To avoid making (n_layer * n_tok) graph nodes, let's use a custom operator.
-            // When n_tok == 1, it's equivalent to the following:
-            //   ssm_state = ggml_add(ctx0, ggml_mul(ctx0, ssm_state, dA), cur);
-            // When n_tok is bigger, it's the same thing, but iterated n_tok times,
-            // with the correct dA and cur for each token.
-            // The resulting states are layered on the ne[2] dimension.
+            // Custom operator to implement some of the optimizations
+            // described in Annex D of the Mamba paper.
+            // TODO: maybe also optimize step 4 of the Speed section of Annex D (the mul_mat with C)
             // => {d_state, d_inner, n_tok}
-            ssm_state = ggml_ssm_scan(ctx0, ssm_state, dA, cur);
+            ssm_state = ggml_ssm_scan(ctx0, ssm_state, x, dt, model.layers[il].ssm_a, B);
 
             // only store last state
             ggml_build_forward_expand(gf,
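Regarding the remaining TODO: step 4 of the Speed section of Annex D is the per-token output contraction y_t = C_t^T h_t, which this patch still performs as a mul_mat with C over the stacked per-token states. A minimal sketch of what that unfused step computes, assuming the {d_state, d_inner, n_tok} state layout produced by ggml_ssm_scan (names are illustrative, not part of the patch):

    // Sketch of the step-4 contraction mentioned in the TODO: for each token,
    // y[i1] = sum over d_state of C[i0] * state[i0, i1].
    static void ssm_output(
            float       * y,      // {d_inner, n_tok}
            const float * states, // {d_state, d_inner, n_tok} from ggml_ssm_scan
            const float * C,      // {d_state, n_tok}
            int d_state, int d_inner, int n_tok) {
        for (int t = 0; t < n_tok; ++t) {
            for (int i1 = 0; i1 < d_inner; ++i1) {
                float sum = 0.0f;
                for (int i0 = 0; i0 < d_state; ++i0) {
                    sum += C[i0 + (size_t) t*d_state]
                         * states[i0 + i1*d_state + (size_t) t*d_state*d_inner];
                }
                y[i1 + (size_t) t*d_inner] = sum;
            }
        }
    }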