diff --git a/ggml.c b/ggml.c index 3fa290f6f..981a2302a 100644 --- a/ggml.c +++ b/ggml.c @@ -1828,6 +1828,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "FLASH_ATTN", "FLASH_FF", "FLASH_ATTN_BACK", + "SSM_CONV", "SSM_SCAN", "WIN_PART", "WIN_UNPART", @@ -1851,7 +1852,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); +static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1915,6 +1916,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "flash_attn(x)", "flash_ff(x)", "flash_attn_back(x)", + "ssm_conv(x)", "ssm_scan(x)", "win_part(x)", "win_unpart(x)", @@ -1938,7 +1940,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); +static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -6079,6 +6081,51 @@ struct ggml_tensor * ggml_flash_attn_back( return result; } +// ggml_ssm_conv + +struct ggml_tensor * ggml_ssm_conv( + struct ggml_context * ctx, + struct ggml_tensor * s, + struct ggml_tensor * x, + struct ggml_tensor * c, + struct ggml_tensor * sq) { + GGML_ASSERT(ggml_is_3d(s)); + GGML_ASSERT(ggml_is_matrix(x)); + GGML_ASSERT(ggml_is_matrix(c)); + GGML_ASSERT(ggml_is_matrix(sq)); + GGML_ASSERT(sq->type == GGML_TYPE_I32); + + const int64_t d_conv = c->ne[0]; + const int64_t d_inner = c->ne[1]; + const int64_t n_tokens = x->ne[1]; + const int64_t n_kv = s->ne[2]; + + GGML_ASSERT( s->ne[0] == d_conv - 1); + GGML_ASSERT( s->ne[1] == d_inner); + GGML_ASSERT( x->ne[0] == d_inner); + GGML_ASSERT(sq->ne[0] == n_kv); + GGML_ASSERT(sq->ne[1] == n_tokens); + + bool is_node = false; + + if (s->grad || x->grad || c->grad || sq->grad) { + GGML_ASSERT(false); // TODO: implement + is_node = true; + } + + // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv} + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv)); + + result->op = GGML_OP_SSM_CONV; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = s; + result->src[1] = x; + result->src[2] = c; + result->src[3] = sq; + + return result; +} + // ggml_ssm_scan struct ggml_tensor * ggml_ssm_scan( @@ -6088,11 +6135,13 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C) { + struct ggml_tensor * C, + struct ggml_tensor * sq) { GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(x)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); + GGML_ASSERT(sq->type == GGML_TYPE_I32); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); GGML_ASSERT(ggml_are_same_shape(x, dt)); @@ -6113,7 +6162,7 @@ struct ggml_tensor * ggml_ssm_scan( bool is_node = false; - if (s->grad || x->grad || dt->grad || A->grad || B->grad) { + if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) { GGML_ASSERT(false); // TODO: implement is_node = true; } @@ -6129,6 +6178,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; + result->src[6] = sq; return result; } @@ -14646,6 +14696,135 @@ static void ggml_compute_forward_flash_attn_back( } } +// ggml_compute_forward_ssm_conv + +static void ggml_compute_forward_ssm_conv_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, // conv_state + const struct ggml_tensor * src1, // x + const struct ggml_tensor * src2, // conv1d.weight + const struct ggml_tensor * src3, // state_seq + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src2->ne[0]; // d_conv + const int nr = src0->ne[1]; // d_inner + const int n_t = src1->ne[1]; // n_tokens + const int n_kv = src0->ne[2]; // max number of sequences in the batch + + GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src2->nb[0] == sizeof(float)); + GGML_ASSERT(src3->nb[0] == sizeof(int32_t)); + GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + // for use with the destination state offset between sequences + GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + const int ir = ir1 - ir0; + + if (n_kv > 1) { + // multiple sequences means it's hard to know when it's the first time a state is read, + // so copy them all over to the destination, just to be sure. 
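For reference, the following is a minimal single-sequence sketch of what ggml_ssm_conv computes per channel: a causal depthwise convolution over the last d_conv inputs, with the rolling window carried in the conv state between tokens. It is not part of the patch; the function name, the plain row-major layout and the d_conv >= 2 assumption are illustrative only, and the actual kernel (which also handles multiple sequences via the state_seq tensor) follows below.

// Illustrative reference, single sequence only (n_kv == 1), plain row-major arrays, d_conv >= 2.
static void ssm_conv_reference_single_seq(
        float       * x_out,       // [n_tokens][d_inner]    convolved output
        float       * conv_state,  // [d_inner][d_conv - 1]  rolling window, updated in place
        const float * x,           // [n_tokens][d_inner]    new inputs
        const float * c,           // [d_inner][d_conv]      conv1d weight
        int d_conv, int d_inner, int n_tokens) {
    for (int t = 0; t < n_tokens; ++t) {
        for (int i = 0; i < d_inner; ++i) {
            // dot product of the current window [conv_state .., x[t][i]] with this channel's weights
            float sum = 0.0f;
            for (int k = 0; k < d_conv - 1; ++k) {
                sum += conv_state[i*(d_conv - 1) + k] * c[i*d_conv + k];
            }
            sum += x[t*d_inner + i] * c[i*d_conv + (d_conv - 1)];
            x_out[t*d_inner + i] = sum;
            // slide the window: drop the oldest column, append the new input
            for (int k = 0; k < d_conv - 2; ++k) {
                conv_state[i*(d_conv - 1) + k] = conv_state[i*(d_conv - 1) + k + 1];
            }
            conv_state[i*(d_conv - 1) + (d_conv - 2)] = x[t*d_inner + i];
        }
    }
}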
+ for (int i3 = 0; i3 < n_kv; ++i3) { + float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); + float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float)); + // can't use memcpy because of d_conv vs d_conv - 1 + for (int i1 = 0; i1 < ir; ++i1) { + for (int i0 = 0; i0 < nc - 1; ++i0) { + // copy s0 to last (d_conv - 1) columns of s + s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)]; + } + } + } + } + + for (int i2 = 0; i2 < n_t; ++i2) { + int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens} + float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} + float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv} + float * s0; // {d_conv - 1, d_inner, n_kv} + float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} + int ne0s0; + + GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); + + // avoid needing to copy the state for the first token + if (i2 == 0) { + s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv} + ne0s0 = src0->ne[0]; + } else { + // the source is the last (d_conv - 1) columns of the destination + s0 = s + 1; + ne0s0 = nc; + } + + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // shift state left + for (int i0 = 0; i0 < nc - 1; ++i0) { + s[i0 + i1*nc] = s0[i0 + i1*ne0s0]; + } + // insert x on the last column + s[(nc - 1) + i1*nc] = x0[i1]; + } + + // handle copies when there are multiple output states + for (int i3 = 1; i3 < n_kv; ++i3) { + int32_t seq = sq[i3]; + if (0 <= seq && seq < n_kv) { + float * s1 = s + (seq - sq[0])*nc*nr; + memcpy(s1, s, nc*ir*sizeof(float)); + } else { + // stop at negative or too big seq_ids + break; + } + } + + // it seems a little faster when this is separate from the state shift + for (int i1 = 0; i1 < ir; ++i1) { + // rowwise dot product + float sumf = 0.0f; + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + sumf += s[i] * c[i]; + } + x[i1] = sumf; + } + } +} + +static void ggml_compute_forward_ssm_conv( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * src2, + const struct ggml_tensor * src3, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_ssm_conv_f32(params, src0, src1, src2, src3, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_ssm_scan static void ggml_compute_forward_ssm_scan_f32( @@ -14656,6 +14835,7 @@ static void ggml_compute_forward_ssm_scan_f32( const struct ggml_tensor * src3, // A const struct ggml_tensor * src4, // B const struct ggml_tensor * src5, // C + const struct ggml_tensor * src6, // sq struct ggml_tensor * dst) { if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; @@ -14664,9 +14844,10 @@ static void ggml_compute_forward_ssm_scan_f32( const int ith = params->ith; const int nth = params->nth; - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens in the batch + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // d_inner + const int64_t n_t = 
src1->ne[1]; // number of tokens in the batch + const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); @@ -14675,9 +14856,11 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C + // required for the dot product between s and C, and when copying the states GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // required to get correct offset for state destination + // required for per-sequence offsets for states + GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); + // required to get correct offset for state destination (i.e. src1->nb[2]) GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float)); // rows per thread @@ -14688,16 +14871,37 @@ static void ggml_compute_forward_ssm_scan_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - // first token in the batch - { - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0])); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + src1->nb[2]); // {d_state, d_inner, n_kv} - float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner, n_kv} - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens} - float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data); // {d_state, n_tokens} - float * C = (float *) ((char *) src5->data); // {d_state, n_tokens} + if (n_kv > 1) { + // it's hard to know if the source states have already been copied + // when there are multiple, so copy them already. 
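For reference, here is a minimal single-sequence, single-token sketch of the recurrence that the loop below implements: dt goes through a softplus, the state is updated as state*exp(dt*A) + B*(x*dt), and the output is the row-wise dot product of the new state with C. It is not part of the patch; the names and the plain row-major layout are illustrative assumptions.

#include <math.h>

// Illustrative reference, single sequence (n_kv == 1), one token.
static void ssm_scan_reference_single_token(
        float       * y,      // [d_inner]            output
        float       * state,  // [d_inner][d_state]   updated in place
        const float * x,      // [d_inner]
        const float * dt,     // [d_inner]
        const float * A,      // [d_inner][d_state]
        const float * B,      // [d_state]
        const float * C,      // [d_state]
        int d_inner, int d_state) {
    for (int i = 0; i < d_inner; ++i) {
        const float dt_soft_plus = log1pf(expf(dt[i]));  // softplus, as in the kernel below
        const float x_dt         = x[i] * dt_soft_plus;
        float sumf = 0.0f;
        for (int j = 0; j < d_state; ++j) {
            const int idx = i*d_state + j;
            // discretized update: state = prev_state * exp(dt*A) + B * (x*dt)
            const float s = state[idx]*expf(dt_soft_plus*A[idx]) + B[j]*x_dt;
            // y = rowwise dot product of the new state with C
            sumf += s * C[j];
            state[idx] = s;
        }
        y[i] = sumf;
    }
}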
+ for (int i3 = 0; i3 < n_kv; ++i3) { + float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); + float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]); + memcpy(s, s0, nc*ir*sizeof(float)); + } + } + + for (int i2 = 0; i2 < n_t; ++i2) { + int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens} + float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv} + float * s0; + float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} + float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} + float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} + + GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); + + // avoid needing to copy the state for the first token + if (i2 == 0) { + s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv} + } else { + // otherwise the source is the same as the destination + s0 = s; + } + // d_inner for (int i1 = 0; i1 < ir; ++i1) { float dt_soft_plus = log1pf(expf(dt[i1])); @@ -14707,41 +14911,24 @@ static void ggml_compute_forward_ssm_scan_f32( for (int i0 = 0; i0 < nc; ++i0) { int i = i0 + i1*nc; // state = prev_state * dA + dB * x - float state = s0[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); // y = rowwise_dotprod(state, C) - sumf += state*C[i0]; - // FIXME: handle simultaneous sequences + sumf += state * C[i0]; s[i] = state; } y[i1] = sumf; } - } - // rest of the batch, state comes from previous one which was stored in destination - for (int i2 = 1; i2 < n_t; ++i2) { - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + src1->nb[2]); // {d_state, d_inner, n_kv} - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} - float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} - float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - float dt_soft_plus = log1pf(expf(dt[i1])); - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state*C[i0]; - // FIXME: handle simultaneous sequences - s[i] = state; + // handle copies when there are multiple output states + for (int i3 = 1; i3 < n_kv; ++i3) { + int32_t seq = sq[i3]; + if (0 <= seq && seq < n_kv) { + float * s1 = s + (seq - sq[0])*nc*nr; + memcpy(s1, s, nc*ir*sizeof(float)); + } else { + // stop at negative or too big seq_ids + 
break; } - y[i1] = sumf; } } } @@ -14754,11 +14941,12 @@ static void ggml_compute_forward_ssm_scan( const struct ggml_tensor * src3, const struct ggml_tensor * src4, const struct ggml_tensor * src5, + const struct ggml_tensor * src6, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, src5, dst); + ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, src5, src6, dst); } break; default: { @@ -15818,9 +16006,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm bool masked = t != 0; ggml_compute_forward_flash_attn_back(params, masked, tensor); } break; + case GGML_OP_SSM_CONV: + { + ggml_compute_forward_ssm_conv(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } break; case GGML_OP_SSM_SCAN: { - ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5], tensor); + ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5], tensor->src[6], tensor); } break; case GGML_OP_WIN_PART: { @@ -16868,6 +17060,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // not supported } break; + case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: { GGML_ASSERT(false); // TODO: not implemented @@ -17569,6 +17762,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { { n_tasks = n_threads; } break; + case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: { n_tasks = n_threads; diff --git a/ggml.h b/ggml.h index fdf251911..6d5cf7696 100644 --- a/ggml.h +++ b/ggml.h @@ -460,6 +460,7 @@ extern "C" { GGML_OP_FLASH_ATTN, GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, + GGML_OP_SSM_CONV, GGML_OP_SSM_SCAN, GGML_OP_WIN_PART, GGML_OP_WIN_UNPART, @@ -1702,6 +1703,13 @@ extern "C" { struct ggml_tensor * c0, struct ggml_tensor * c1); + GGML_API struct ggml_tensor * ggml_ssm_conv( + struct ggml_context * ctx, + struct ggml_tensor * s, + struct ggml_tensor * x, + struct ggml_tensor * c, + struct ggml_tensor * sq); + GGML_API struct ggml_tensor * ggml_ssm_scan( struct ggml_context * ctx, struct ggml_tensor * s, @@ -1709,7 +1717,8 @@ extern "C" { struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C); + struct ggml_tensor * C, + struct ggml_tensor * sq); // partition into non-overlapping windows with padding if needed // example: diff --git a/llama.cpp b/llama.cpp index 2613340cc..ad0322674 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2044,6 +2044,7 @@ struct llama_context { struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_s_mask; // F32 [kv_size] (only used by constant state models like Mamba) + struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch] #ifdef GGML_USE_MPI ggml_mpi_context * ctx_mpi = NULL; @@ -4761,8 +4762,9 @@ static bool llm_load_tensors( const int64_t d_conv = hparams.n_embd_head_k + 1; const int64_t d_state = hparams.n_embd_head_v; const int64_t d_inner = hparams.n_head; - // FIXME: ceiling instead of floor - const int64_t dt_rank = n_embd / 16; + // TODO: allow loading dt_rank from the model config + // ceiling division + const int64_t dt_rank = (n_embd / 16) + (n_embd % 16 > 0); GGML_ASSERT(2 * n_embd == d_inner); model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, 
n_vocab});
@@ -8012,13 +8014,12 @@ struct llm_build_context {
         GGML_ASSERT(2 * d_model == d_inner);
         const int64_t d_conv = n_embd_head_k + 1;
         const int64_t d_state = n_embd_head_v;
-        const int64_t dt_rank = d_model / 16;
+        // ceiling division
+        const int64_t dt_rank = (d_model / 16) + (d_model % 16 > 0);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        GGML_ASSERT(kv_self.used - kv_self.head + 1 == 1); // TODO: support more than one sequence per batch
-
         // {n_embd, n_tokens}
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
@@ -8040,8 +8041,8 @@ struct llm_build_context {
                     state_mask);
         }
 
-        struct ggml_tensor * conv_state = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
-        struct ggml_tensor * ssm_state = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv);
+        conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
+        ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv);
 
         // norm
         cur = llm_build_norm(ctx0, inpL, hparams,
@@ -8056,37 +8057,31 @@ struct llm_build_context {
             struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
             struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
 
+            struct ggml_tensor * state_seq = ggml_view_2d(ctx0, lctx.inp_s_seq, n_kv, n_tokens, n_kv*ggml_element_size(lctx.inp_s_seq), 0);
+
             // conv
             {
-                // concat last (d_conv - 1) columns of conv_state, and x
+                // Custom operator which is needed only to ease simultaneous sequence processing.
+                // For a single sequence, the equivalent is to concatenate the columns of conv_states and x,
+                // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension,
+                // then element-wise multiply that with the conv1d weight,
+                // then sum the elements of each row,
+                // (the last two steps are a dot product over rows (also doable with mul_mat))
+                // then permute away the ne[0] dimension,
+                // and then you're left with the resulting x tensor.
+                // The new conv_states is the last (d_conv - 1) columns
+                // of the last 3rd dimensional "layer" of the self-overlapping view.
+                // For simultaneous sequences, it's more complicated.
+                struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq);
-                // The following tensor is too big in order to avoid an assertion error when making an overlapping view.
-                // TODO: in ggml_new_tensor_impl, handle overlapping data range in data size calculation
-                // This could then be a tensor with ne[] = {(d_conv-1)+n_tokens, d_inner},
-                // but the size difference is not that big (d_conv is usually 4).
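To make the equivalence described in the comment above concrete, here is a single-sequence, single-channel sketch of the "concatenate, then slide a d_conv-wide window" formulation. It is illustrative only (names and layout are assumptions, not part of the patch) and produces the same per-token dot products as the rolling-window kernel in ggml.c.

// conv_row: the (d_conv - 1) stored state values for one channel, followed by that
// channel's inputs for the whole batch, i.e. [(d_conv - 1) + n_tokens] values.
// For token t, the window is conv_row[t .. t + d_conv - 1].
static void ssm_conv_windows_single_channel(
        float       * x_out,     // [n_tokens]                 convolved output for this channel
        const float * conv_row,  // [(d_conv - 1) + n_tokens]
        const float * c,         // [d_conv]                   conv1d weights for this channel
        int d_conv, int n_tokens) {
    for (int t = 0; t < n_tokens; ++t) {
        float sum = 0.0f;
        for (int k = 0; k < d_conv; ++k) {
            sum += conv_row[t + k] * c[k];  // element-wise multiply + row sum = dot product
        }
        x_out[t] = sum;
    }
    // the conv state to keep for the next batch is the last (d_conv - 1) entries,
    // conv_row[n_tokens .. n_tokens + d_conv - 2]
}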
- struct ggml_tensor * conv_x = ggml_new_tensor_1d(ctx0, conv_state->type, d_conv*d_inner*n_tokens); - const size_t conv_x_nb1 = (d_conv - 1 + n_tokens) * ggml_element_size(conv_x); - - conv_x = ggml_set_2d(ctx0, conv_x, conv_state, conv_x_nb1, 0); - // making x contiguous is necessary because ggml_set expects it - conv_x = ggml_set_2d(ctx0, conv_x, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_x_nb1, (d_conv - 1)*ggml_element_size(conv_x)); - - // store last (d_conv - 1) columns of conv_x back into the KV cache for the next conv_state + // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tokens*ggml_element_size(conv_x)), - ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_x)))); + ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), + ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); - // prepare convolution for all tokens in the batch with a self-overlapping view, - // shifting by one column each ... depth? ... with a window of d_conv columns. - // {(d_conv-1)+n_tokens, d_inner} => {d_conv, d_inner, n_tokens} - conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tokens, conv_x_nb1, 1*ggml_element_size(conv_x), 0); - - // perform convolution - // => {1, d_inner, n_tokens} - x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_x, model.layers[il].ssm_conv1d)); - // => {d_inner, n_tokens, 1} - x = ggml_permute(ctx0, x, 2, 0, 1, 3); + // extract x from x_conv + x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); // bias x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b); @@ -8111,13 +8106,13 @@ struct llm_build_context { // as described in the Annex D of the Mamba paper. // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined, // because only a single tensor can be returned. - struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_state, x, dt, model.layers[il].ssm_a, B, C); + struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); // store last states (the second part of y_ssm_states) ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)), - ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_state)))); + ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_states)))); struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); @@ -8362,7 +8357,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { float * data = (float *) lctx.inp_KQ_mask->data; - // For Transformers, use only the previous KV cells + // For Transformers, use only the previous KV cells (or all, when non-causal) // of the correct sequence for each token of the batch. // It's assumed that if a token in the batch has multiple sequences, they are equivalent. 
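The byte offsets in the ggml_view_2d used above to write the conv-state part of x_conv back into kv_self.k_l can be easy to misread, so here is a stand-alone check with made-up sizes. It is only meant to illustrate why the offset is (1 + d_inner*n_tokens) elements and the row stride is d_conv elements; nothing below is part of the patch.

#include <assert.h>
#include <stddef.h>

int main(void) {
    // hypothetical sizes, chosen only to make the offsets concrete
    const size_t d_conv = 4, d_inner = 2, n_tokens = 3, n_kv = 1;
    const size_t f32 = sizeof(float);

    const size_t x_elems     = d_inner*n_tokens;     // convolved x, {d_inner, n_tokens}
    const size_t state_elems = d_conv*d_inner*n_kv;  // full windows, {d_conv, d_inner, n_kv}
    assert(x_elems + state_elems == 14);             // total number of elements in x_conv

    // the view copied back into kv_self.k_l keeps only the last (d_conv - 1)
    // columns of each d_conv-wide window, hence the "+ 1" in the offset
    const size_t view_offset = (1 + x_elems)*f32;    // skip x, then skip the oldest column
    const size_t view_nb1    = d_conv*f32;           // row stride: one full window per row
    assert(view_offset == 7*f32 && view_nb1 == 4*f32);
    return 0;
}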
for (int h = 0; h < 1; ++h) { @@ -8382,13 +8377,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } - // For Mamba (and other constant-time-and-size architectures), - // update the correct state(s)/sequence(s) for each token of the batch. - // Source and destination states are both the same for the sake of implementation simplicity. - // It would be more complex if they were sometimes the same and somtimes not. - // (with Transformers, source KV cells are never the destination, - // which is also simpler, but more memory hungry) - // TODO: implement } if (hparams.need_kq_pos) { @@ -8447,28 +8435,54 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } if (kv_self.unlimited) { - const uint32_t kv_size = kv_self.size; - const uint32_t n_kv = kv_self.n; + const int64_t kv_size = kv_self.size; + const int64_t n_kv = kv_self.n; - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); - float * data = (float *) lctx.inp_s_mask->data; + { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); + float * data = (float *) lctx.inp_s_mask->data; - // states which are not affected by the current batch are left untouched - for (uint32_t i = 0; i < n_kv; ++i) { - llama_seq_id seq_id = i + lctx.kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id]; - bool has_self_seq = kv_cell.has_seq_id(seq_id); + // states which are not affected by the current batch are left untouched + for (int i = 0; i < n_kv; ++i) { + llama_seq_id seq_id = i + lctx.kv_self.head; + llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id]; + bool has_self_seq = kv_cell.has_seq_id(seq_id); - data[i] = (float) has_self_seq; + data[i] = (float) has_self_seq; - // ensure current sequences will be kept - if (!has_self_seq) { - kv_cell.seq_id.insert(seq_id); + // ensure current sequences will be kept + if (!has_self_seq) { + kv_cell.seq_id.insert(seq_id); + } + } + } + // For Mamba (and other constant-time-and-size architectures), + // update the correct state(s)/sequence(s) for each token of the batch. + // Like with the KQ_mask, if a token in the batch has multiple sequences, + // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv). 
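As an illustration of the encoding that the inp_s_seq fill loop below produces (a hypothetical batch, not taken from the patch): each token gets a column of n_kv entries, the first entry names the state slot the token reads and writes, further non-negative entries name extra slots that receive a copy of the new state in ggml_ssm_conv and ggml_ssm_scan, and -1 marks unused entries.

#include <stdint.h>

// Hypothetical batch: n_kv = 2, kv_self.head = 0, three tokens;
// tokens 0 and 1 belong to seq 0 only, token 2 belongs to both seq 0 and seq 1
// (e.g. the last shared token right before the sequences diverge).
static const int32_t inp_s_seq_example[3 /* n_tokens */][2 /* n_kv */] = {
    { 0, -1 },  // token 0: read/write state slot 0
    { 0, -1 },  // token 1: read/write state slot 0
    { 0,  1 },  // token 2: read/write slot 0, then also copy the new state into slot 1
};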
+ { + const int64_t n_tokens = batch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer)); + int32_t * data = (int32_t *) lctx.inp_s_seq->data; + + for (int j = 0; j < n_tokens; ++j) { + const int32_t n_seq = batch.n_seq_id[j]; + GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence + + for (int i = 0; i < n_kv; ++i) { + if (i < n_seq) { + // for this type of model, the head is the minimum seq_id of the batch + data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head; + } else { + data[j*n_kv + i] = -1; + } + } } } // remove extraneous seq_ids when state copies are made { - for (uint32_t i = 0; i < kv_size; ++i) { + for (int i = 0; i < kv_size; ++i) { llama_kv_cell & kv_cell = lctx.kv_self.cells[i]; uint32_t n_seqs = kv_cell.seq_id.size(); bool has_self_seq = kv_cell.has_seq_id(i); @@ -12642,7 +12656,7 @@ struct llama_context * llama_new_context_with_model( // graph inputs { ggml_init_params init_params = { - /* .mem_size */ ggml_tensor_overhead()*(8 + ctx->kv_self.unlimited), + /* .mem_size */ ggml_tensor_overhead()*(8 + 2*(ctx->kv_self.unlimited)), /* .mem_buffer */ nullptr, /* .no_alloc */ true, }; @@ -12656,8 +12670,10 @@ struct llama_context * llama_new_context_with_model( ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size); ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch); ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); - if (ctx->kv_self.unlimited) + if (ctx->kv_self.unlimited) { ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size); + ctx->inp_s_seq = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch); + } ggml_set_name(ctx->inp_tokens, "inp_tokens"); ggml_set_name(ctx->inp_embd, "inp_embd"); @@ -12667,8 +12683,10 @@ struct llama_context * llama_new_context_with_model( ggml_set_name(ctx->inp_K_shift, "inp_K_shift"); ggml_set_name(ctx->inp_mean, "inp_mean"); ggml_set_name(ctx->inp_cls, "inp_cls"); - if (ctx->kv_self.unlimited) + if (ctx->kv_self.unlimited) { ggml_set_name(ctx->inp_s_mask, "inp_s_mask"); + ggml_set_name(ctx->inp_s_seq, "inp_s_seq"); + } ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true)); LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__,
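For context on what inp_s_mask and inp_s_seq index into, the following is a rough sketch of the per-cell recurrent state footprint implied by the k_l and v_l views earlier in this diff; the helper is an illustration, not an API added by the patch.

#include <stddef.h>

// Per KV cell and per layer, the views above store a conv window of
// (d_conv - 1)*d_inner floats in kv_self.k_l[il] and an SSM state of
// d_state*d_inner floats in kv_self.v_l[il].
static size_t mamba_state_bytes_per_cell(size_t d_conv, size_t d_inner, size_t d_state, size_t n_layer) {
    const size_t conv_floats = (d_conv - 1)*d_inner;  // stored in k_l
    const size_t ssm_floats  = d_state*d_inner;       // stored in v_l
    return n_layer*(conv_floats + ssm_floats)*sizeof(float);
}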