From 9473ec2147f06f1c9cdd44c0caac5811493abc68 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 18 Feb 2024 20:57:30 -0500 Subject: [PATCH] mamba : simultaneous sequence processing A batch can now contain tokens from multiple sequences. This is necessary for at least the parallel example, the server example, and the HellaSwag test in the perplexity example. However, for this to be useful, uses of llama_kv_cache_seq_rm/cp will need to be changed to work on whole sequences. * ggml : add ggml_ssm_conv as a new operator for the conv step of Mamba This operator makes it possible to use and update the correct states for each token of the batch in the same way as ggml_ssm_scan. Other solutions which use existing operators would need loops which would add too many nodes to the graph (at least the ones I thought of). Using this operator further reduces the size of the CPU compute buffer from 140.68 MiB to 103.20 MiB with Mamba 3B with a batch size of 512. And (at least on CPU), it's a bit faster than before. Note that "ggml_ssm_conv" is probably not the most appropriate name, and it could be changed if a better one is found. * llama : add inp_s_seq as a new input tensor The most convenient implementation to select the correct state (for Mamba) for each token is to directly get the correct index from a tensor. This is why inp_s_seq is storing int32_t and not floats. The other, less convenient way to select the correct state would be to have inp_KQ_mask contain 1.0f for each state used by a token and 0.0f otherwise. This complicates quickly fetching the first used state of a token, and is also less efficient because a whole row of the mask would always need to be read for each token. Using indexes makes it easy to stop searching when there are no more sequences for a token, and the first sequence assigned is always very quickly available (it's the first element of each row). --- ggml.c | 292 +++++++++++++++++++++++++++++++++++++++++++++--------- ggml.h | 11 +- llama.cpp | 136 ++++++++++++++----------- 3 files changed, 330 insertions(+), 109 deletions(-) diff --git a/ggml.c b/ggml.c index 3fa290f6f..981a2302a 100644 --- a/ggml.c +++ b/ggml.c @@ -1828,6 +1828,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "FLASH_ATTN", "FLASH_FF", "FLASH_ATTN_BACK", + "SSM_CONV", "SSM_SCAN", "WIN_PART", "WIN_UNPART", @@ -1851,7 +1852,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); +static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1915,6 +1916,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "flash_attn(x)", "flash_ff(x)", "flash_attn_back(x)", + "ssm_conv(x)", "ssm_scan(x)", "win_part(x)", "win_unpart(x)", @@ -1938,7 +1940,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); +static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -6079,6 +6081,51 @@ struct ggml_tensor * ggml_flash_attn_back( return result; } +// ggml_ssm_conv + +struct ggml_tensor * ggml_ssm_conv( + struct ggml_context * ctx, + struct ggml_tensor * s, + struct ggml_tensor * x, + struct ggml_tensor * c, + struct ggml_tensor * sq) { + GGML_ASSERT(ggml_is_3d(s)); + GGML_ASSERT(ggml_is_matrix(x)); + GGML_ASSERT(ggml_is_matrix(c)); + GGML_ASSERT(ggml_is_matrix(sq)); + GGML_ASSERT(sq->type == GGML_TYPE_I32); + + const int64_t d_conv = c->ne[0]; + const int64_t d_inner = c->ne[1]; + const int64_t n_tokens = x->ne[1]; + const int64_t n_kv = s->ne[2]; + + GGML_ASSERT( s->ne[0] == d_conv - 1); + GGML_ASSERT( s->ne[1] == d_inner); + GGML_ASSERT( x->ne[0] == d_inner); + GGML_ASSERT(sq->ne[0] == n_kv); + GGML_ASSERT(sq->ne[1] == n_tokens); + + bool is_node = false; + + if (s->grad || x->grad || c->grad || sq->grad) { + GGML_ASSERT(false); // TODO: implement + is_node = true; + } + + // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv} + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv)); + + result->op = GGML_OP_SSM_CONV; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = s; + result->src[1] = x; + result->src[2] = c; + result->src[3] = sq; + + return result; +} + // ggml_ssm_scan struct ggml_tensor * ggml_ssm_scan( @@ -6088,11 +6135,13 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C) { + struct ggml_tensor * C, + struct ggml_tensor * sq) { GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(x)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); + GGML_ASSERT(sq->type == GGML_TYPE_I32); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); GGML_ASSERT(ggml_are_same_shape(x, dt)); @@ -6113,7 +6162,7 @@ struct ggml_tensor * ggml_ssm_scan( bool is_node = false; - if (s->grad || x->grad || dt->grad || A->grad || B->grad) { + if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) { GGML_ASSERT(false); // TODO: implement is_node = true; } @@ -6129,6 +6178,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; + result->src[6] = sq; return result; } @@ -14646,6 +14696,135 @@ static void ggml_compute_forward_flash_attn_back( } } +// ggml_compute_forward_ssm_conv + +static void ggml_compute_forward_ssm_conv_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, // conv_state + const struct ggml_tensor * src1, // x + const struct ggml_tensor * src2, // conv1d.weight + const struct ggml_tensor * src3, // state_seq + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src2->ne[0]; // d_conv + const int nr = src0->ne[1]; // d_inner + const int n_t = src1->ne[1]; // n_tokens + const int n_kv = src0->ne[2]; // max number of sequences in the batch + + GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src2->nb[0] == sizeof(float)); + GGML_ASSERT(src3->nb[0] == sizeof(int32_t)); + GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + // for use with the destination state offset between sequences + GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + const int ir = ir1 - ir0; + + if (n_kv > 1) { + // multiple sequences means it's hard to know when it's the first time a state is read, + // so copy them all over to the destination, just to be sure. + for (int i3 = 0; i3 < n_kv; ++i3) { + float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); + float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float)); + // can't use memcpy because of d_conv vs d_conv - 1 + for (int i1 = 0; i1 < ir; ++i1) { + for (int i0 = 0; i0 < nc - 1; ++i0) { + // copy s0 to last (d_conv - 1) columns of s + s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)]; + } + } + } + } + + for (int i2 = 0; i2 < n_t; ++i2) { + int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens} + float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} + float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv} + float * s0; // {d_conv - 1, d_inner, n_kv} + float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} + int ne0s0; + + GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); + + // avoid needing to copy the state for the first token + if (i2 == 0) { + s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv} + ne0s0 = src0->ne[0]; + } else { + // the source is the last (d_conv - 1) columns of the destination + s0 = s + 1; + ne0s0 = nc; + } + + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // shift state left + for (int i0 = 0; i0 < nc - 1; ++i0) { + s[i0 + i1*nc] = s0[i0 + i1*ne0s0]; + } + // insert x on the last column + s[(nc - 1) + i1*nc] = x0[i1]; + } + + // handle copies when there are multiple output states + for (int i3 = 1; i3 < n_kv; ++i3) { + int32_t seq = sq[i3]; + if (0 <= seq && seq < n_kv) { + float * s1 = s + (seq - sq[0])*nc*nr; + memcpy(s1, s, nc*ir*sizeof(float)); + } else { + // stop at negative or too big seq_ids + break; + } + } + + // it seems a little faster when this is separate from the state shift + for (int i1 = 0; i1 < ir; ++i1) { + // rowwise dot product + float sumf = 0.0f; + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + sumf += s[i] * c[i]; + } + x[i1] = sumf; + } + } +} + +static void ggml_compute_forward_ssm_conv( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * src2, + const struct ggml_tensor * src3, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_ssm_conv_f32(params, src0, src1, src2, src3, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_ssm_scan static void ggml_compute_forward_ssm_scan_f32( @@ -14656,6 +14835,7 @@ static void ggml_compute_forward_ssm_scan_f32( const struct ggml_tensor * src3, // A const struct ggml_tensor * src4, // B const struct ggml_tensor * src5, // C + const struct ggml_tensor * src6, // sq struct ggml_tensor * dst) { if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; @@ -14664,9 +14844,10 @@ static void ggml_compute_forward_ssm_scan_f32( const int ith = params->ith; const int nth = params->nth; - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens in the batch + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // d_inner + const int64_t n_t = src1->ne[1]; // number of tokens in the batch + const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); @@ -14675,9 +14856,11 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C + // required for the dot product between s and C, and when copying the states GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // required to get correct offset for state destination + // required for per-sequence offsets for states + GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); + // required to get correct offset for state destination (i.e. src1->nb[2]) GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float)); // rows per thread @@ -14688,16 +14871,37 @@ static void ggml_compute_forward_ssm_scan_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - // first token in the batch - { - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0])); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + src1->nb[2]); // {d_state, d_inner, n_kv} - float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner, n_kv} - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens} - float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data); // {d_state, n_tokens} - float * C = (float *) ((char *) src5->data); // {d_state, n_tokens} + if (n_kv > 1) { + // it's hard to know if the source states have already been copied + // when there are multiple, so copy them already. + for (int i3 = 0; i3 < n_kv; ++i3) { + float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); + float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]); + memcpy(s, s0, nc*ir*sizeof(float)); + } + } + + for (int i2 = 0; i2 < n_t; ++i2) { + int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens} + float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv} + float * s0; + float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} + float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} + float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} + + GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); + + // avoid needing to copy the state for the first token + if (i2 == 0) { + s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv} + } else { + // otherwise the source is the same as the destination + s0 = s; + } + // d_inner for (int i1 = 0; i1 < ir; ++i1) { float dt_soft_plus = log1pf(expf(dt[i1])); @@ -14707,41 +14911,24 @@ static void ggml_compute_forward_ssm_scan_f32( for (int i0 = 0; i0 < nc; ++i0) { int i = i0 + i1*nc; // state = prev_state * dA + dB * x - float state = s0[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); // y = rowwise_dotprod(state, C) - sumf += state*C[i0]; - // FIXME: handle simultaneous sequences + sumf += state * C[i0]; s[i] = state; } y[i1] = sumf; } - } - // rest of the batch, state comes from previous one which was stored in destination - for (int i2 = 1; i2 < n_t; ++i2) { - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + src1->nb[2]); // {d_state, d_inner, n_kv} - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} - float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} - float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - float dt_soft_plus = log1pf(expf(dt[i1])); - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state*C[i0]; - // FIXME: handle simultaneous sequences - s[i] = state; + // handle copies when there are multiple output states + for (int i3 = 1; i3 < n_kv; ++i3) { + int32_t seq = sq[i3]; + if (0 <= seq && seq < n_kv) { + float * s1 = s + (seq - sq[0])*nc*nr; + memcpy(s1, s, nc*ir*sizeof(float)); + } else { + // stop at negative or too big seq_ids + break; } - y[i1] = sumf; } } } @@ -14754,11 +14941,12 @@ static void ggml_compute_forward_ssm_scan( const struct ggml_tensor * src3, const struct ggml_tensor * src4, const struct ggml_tensor * src5, + const struct ggml_tensor * src6, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, src5, dst); + ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, src5, src6, dst); } break; default: { @@ -15818,9 +16006,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm bool masked = t != 0; ggml_compute_forward_flash_attn_back(params, masked, tensor); } break; + case GGML_OP_SSM_CONV: + { + ggml_compute_forward_ssm_conv(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } break; case GGML_OP_SSM_SCAN: { - ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5], tensor); + ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5], tensor->src[6], tensor); } break; case GGML_OP_WIN_PART: { @@ -16868,6 +17060,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // not supported } break; + case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: { GGML_ASSERT(false); // TODO: not implemented @@ -17569,6 +17762,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { { n_tasks = n_threads; } break; + case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: { n_tasks = n_threads; diff --git a/ggml.h b/ggml.h index fdf251911..6d5cf7696 100644 --- a/ggml.h +++ b/ggml.h @@ -460,6 +460,7 @@ extern "C" { GGML_OP_FLASH_ATTN, GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, + GGML_OP_SSM_CONV, GGML_OP_SSM_SCAN, GGML_OP_WIN_PART, GGML_OP_WIN_UNPART, @@ -1702,6 +1703,13 @@ extern "C" { struct ggml_tensor * c0, struct ggml_tensor * c1); + GGML_API struct ggml_tensor * ggml_ssm_conv( + struct ggml_context * ctx, + struct ggml_tensor * s, + struct ggml_tensor * x, + struct ggml_tensor * c, + struct ggml_tensor * sq); + GGML_API struct ggml_tensor * ggml_ssm_scan( struct ggml_context * ctx, struct ggml_tensor * s, @@ -1709,7 +1717,8 @@ extern "C" { struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C); + struct ggml_tensor * C, + struct ggml_tensor * sq); // partition into non-overlapping windows with padding if needed // example: diff --git a/llama.cpp b/llama.cpp index 2613340cc..ad0322674 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2044,6 +2044,7 @@ struct llama_context { struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_s_mask; // F32 [kv_size] (only used by constant state models like Mamba) + struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch] #ifdef GGML_USE_MPI ggml_mpi_context * ctx_mpi = NULL; @@ -4761,8 +4762,9 @@ static bool llm_load_tensors( const int64_t d_conv = hparams.n_embd_head_k + 1; const int64_t d_state = hparams.n_embd_head_v; const int64_t d_inner = hparams.n_head; - // FIXME: ceiling instead of floor - const int64_t dt_rank = n_embd / 16; + // TODO: allow loading dt_rank from the model config + // ceiling division + const int64_t dt_rank = (n_embd / 16) + (n_embd % 16 > 0); GGML_ASSERT(2 * n_embd == d_inner); model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -8012,13 +8014,12 @@ struct llm_build_context { GGML_ASSERT(2 * d_model == d_inner); const int64_t d_conv = n_embd_head_k + 1; const int64_t d_state = n_embd_head_v; - const int64_t dt_rank = d_model / 16; + // ceiling division + const int64_t dt_rank = (d_model / 16) + (d_model % 16 > 0); struct ggml_tensor * cur; struct ggml_tensor * inpL; - GGML_ASSERT(kv_self.used - kv_self.head + 1 == 1); // TODO: support more than one sequence per batch - // {n_embd, n_tokens} inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); cb(inpL, "inp_embd", -1); @@ -8040,8 +8041,8 @@ struct llm_build_context { state_mask); } - struct ggml_tensor * conv_state = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv); - struct ggml_tensor * ssm_state = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv); + conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv); + ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv); // norm cur = llm_build_norm(ctx0, inpL, hparams, @@ -8056,37 +8057,31 @@ struct llm_build_context { struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0); struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner); + struct ggml_tensor * state_seq = ggml_view_2d(ctx0, lctx.inp_s_seq, n_kv, n_tokens, n_kv*ggml_element_size(lctx.inp_s_seq), 0); + // conv { - // concat last (d_conv - 1) columns of conv_state, and x + // Custom operator which is needed only to ease simultaneous sequence processing. + // For a single sequence, the equivalent is to concatenate the columns of conv_states and x, + // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weigth, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // The new conv_states is the last (d_conv - 1) columns + // of the last 3rd dimensional "layer" of the self-overlapping view. + // For simultaneous sequences, it's more complicated. + struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq); - // The following tensor is too big in order to avoid an assertion error when making an overlapping view. - // TODO: in ggml_new_tensor_impl, handle overlapping data range in data size calculation - // This could then be a tensor with ne[] = {(d_conv-1)+n_tokens, d_inner}, - // but the size difference is not that big (d_conv is usually 4). - struct ggml_tensor * conv_x = ggml_new_tensor_1d(ctx0, conv_state->type, d_conv*d_inner*n_tokens); - const size_t conv_x_nb1 = (d_conv - 1 + n_tokens) * ggml_element_size(conv_x); - - conv_x = ggml_set_2d(ctx0, conv_x, conv_state, conv_x_nb1, 0); - // making x contiguous is necessary because ggml_set expects it - conv_x = ggml_set_2d(ctx0, conv_x, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_x_nb1, (d_conv - 1)*ggml_element_size(conv_x)); - - // store last (d_conv - 1) columns of conv_x back into the KV cache for the next conv_state + // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tokens*ggml_element_size(conv_x)), - ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_x)))); + ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), + ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); - // prepare convolution for all tokens in the batch with a self-overlapping view, - // shifting by one column each ... depth? ... with a window of d_conv columns. - // {(d_conv-1)+n_tokens, d_inner} => {d_conv, d_inner, n_tokens} - conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tokens, conv_x_nb1, 1*ggml_element_size(conv_x), 0); - - // perform convolution - // => {1, d_inner, n_tokens} - x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_x, model.layers[il].ssm_conv1d)); - // => {d_inner, n_tokens, 1} - x = ggml_permute(ctx0, x, 2, 0, 1, 3); + // extract x from x_conv + x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); // bias x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b); @@ -8111,13 +8106,13 @@ struct llm_build_context { // as described in the Annex D of the Mamba paper. // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined, // because only a single tensor can be returned. - struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_state, x, dt, model.layers[il].ssm_a, B, C); + struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); // store last states (the second part of y_ssm_states) ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)), - ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_state)))); + ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_states)))); struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); @@ -8362,7 +8357,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { float * data = (float *) lctx.inp_KQ_mask->data; - // For Transformers, use only the previous KV cells + // For Transformers, use only the previous KV cells (or all, when non-causal) // of the correct sequence for each token of the batch. // It's assumed that if a token in the batch has multiple sequences, they are equivalent. for (int h = 0; h < 1; ++h) { @@ -8382,13 +8377,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } - // For Mamba (and other constant-time-and-size architectures), - // update the correct state(s)/sequence(s) for each token of the batch. - // Source and destination states are both the same for the sake of implementation simplicity. - // It would be more complex if they were sometimes the same and somtimes not. - // (with Transformers, source KV cells are never the destination, - // which is also simpler, but more memory hungry) - // TODO: implement } if (hparams.need_kq_pos) { @@ -8447,28 +8435,54 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } if (kv_self.unlimited) { - const uint32_t kv_size = kv_self.size; - const uint32_t n_kv = kv_self.n; + const int64_t kv_size = kv_self.size; + const int64_t n_kv = kv_self.n; - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); - float * data = (float *) lctx.inp_s_mask->data; + { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); + float * data = (float *) lctx.inp_s_mask->data; - // states which are not affected by the current batch are left untouched - for (uint32_t i = 0; i < n_kv; ++i) { - llama_seq_id seq_id = i + lctx.kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id]; - bool has_self_seq = kv_cell.has_seq_id(seq_id); + // states which are not affected by the current batch are left untouched + for (int i = 0; i < n_kv; ++i) { + llama_seq_id seq_id = i + lctx.kv_self.head; + llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id]; + bool has_self_seq = kv_cell.has_seq_id(seq_id); - data[i] = (float) has_self_seq; + data[i] = (float) has_self_seq; - // ensure current sequences will be kept - if (!has_self_seq) { - kv_cell.seq_id.insert(seq_id); + // ensure current sequences will be kept + if (!has_self_seq) { + kv_cell.seq_id.insert(seq_id); + } + } + } + // For Mamba (and other constant-time-and-size architectures), + // update the correct state(s)/sequence(s) for each token of the batch. + // Like with the KQ_mask, if a token in the batch has multiple sequences, + // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv). + { + const int64_t n_tokens = batch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer)); + int32_t * data = (int32_t *) lctx.inp_s_seq->data; + + for (int j = 0; j < n_tokens; ++j) { + const int32_t n_seq = batch.n_seq_id[j]; + GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence + + for (int i = 0; i < n_kv; ++i) { + if (i < n_seq) { + // for this type of model, the head is the minimum seq_id of the batch + data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head; + } else { + data[j*n_kv + i] = -1; + } + } } } // remove extraneous seq_ids when state copies are made { - for (uint32_t i = 0; i < kv_size; ++i) { + for (int i = 0; i < kv_size; ++i) { llama_kv_cell & kv_cell = lctx.kv_self.cells[i]; uint32_t n_seqs = kv_cell.seq_id.size(); bool has_self_seq = kv_cell.has_seq_id(i); @@ -12642,7 +12656,7 @@ struct llama_context * llama_new_context_with_model( // graph inputs { ggml_init_params init_params = { - /* .mem_size */ ggml_tensor_overhead()*(8 + ctx->kv_self.unlimited), + /* .mem_size */ ggml_tensor_overhead()*(8 + 2*(ctx->kv_self.unlimited)), /* .mem_buffer */ nullptr, /* .no_alloc */ true, }; @@ -12656,8 +12670,10 @@ struct llama_context * llama_new_context_with_model( ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size); ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch); ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); - if (ctx->kv_self.unlimited) + if (ctx->kv_self.unlimited) { ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size); + ctx->inp_s_seq = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch); + } ggml_set_name(ctx->inp_tokens, "inp_tokens"); ggml_set_name(ctx->inp_embd, "inp_embd"); @@ -12667,8 +12683,10 @@ struct llama_context * llama_new_context_with_model( ggml_set_name(ctx->inp_K_shift, "inp_K_shift"); ggml_set_name(ctx->inp_mean, "inp_mean"); ggml_set_name(ctx->inp_cls, "inp_cls"); - if (ctx->kv_self.unlimited) + if (ctx->kv_self.unlimited) { ggml_set_name(ctx->inp_s_mask, "inp_s_mask"); + ggml_set_name(ctx->inp_s_seq, "inp_s_seq"); + } ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true)); LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__,