diff --git a/ggml.c b/ggml.c
index 8f351d823..177d8b3c3 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1831,6 +1831,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "FLASH_ATTN",
     "FLASH_FF",
     "FLASH_ATTN_BACK",
+    "SSM_SCAN",
     "WIN_PART",
     "WIN_UNPART",
     "GET_REL_POS",
@@ -1853,7 +1854,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };

-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1919,6 +1920,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_attn(x)",
     "flash_ff(x)",
     "flash_attn_back(x)",
+    "ssm_scan(x)",
     "win_part(x)",
     "win_unpart(x)",
     "get_rel_pos(x)",
@@ -1941,7 +1943,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };

-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -6149,6 +6151,40 @@ struct ggml_tensor * ggml_flash_attn_back(
     return result;
 }

+// ggml_ssm_scan
+
+struct ggml_tensor * ggml_ssm_scan(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * s,
+        struct ggml_tensor  * dA,
+        struct ggml_tensor  * dB_x) {
+    GGML_ASSERT(ggml_are_same_shape(dA, dB_x));
+
+    GGML_ASSERT(   s->nb[0] == ggml_type_size(   s->type));
+    GGML_ASSERT(  dA->nb[0] == ggml_type_size(  dA->type));
+    GGML_ASSERT(dB_x->nb[0] == ggml_type_size(dB_x->type));
+
+    GGML_ASSERT(s->ne[0] == dA->ne[0]);
+    GGML_ASSERT(s->ne[1] == dA->ne[1]);
+    GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1); // the ssm_state should be 2D
+
+    bool is_node = false;
+
+    if (s->grad || dA->grad || dB_x->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, dA->ne[0], dA->ne[1], dA->ne[2]);
+
+    result->op   = GGML_OP_SSM_SCAN;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = s;
+    result->src[1] = dA;
+    result->src[2] = dB_x;
+
+    return result;
+}
+
 // ggml_win_part

 struct ggml_tensor * ggml_win_part(
@@ -14755,6 +14791,78 @@ static void ggml_compute_forward_flash_attn_back(
     }
 }

+// ggml_compute_forward_ssm_scan
+
+static void ggml_compute_forward_ssm_scan_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * src2,
+        struct ggml_tensor * dst) {
+    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nc  = src1->ne[0];
+    const int64_t n_b = src1->ne[2]; // number of batches
+    const int64_t nr0 = ggml_nrows(src0);
+
+    GGML_ASSERT(nc*n_b*nr0 == ggml_nelements(src1));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr0 + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr0);
+
+    // first batch
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dest = (float *) ((char *)  dst->data + i1*( dst->nb[1]));
+        float * s    = (float *) ((char *) src0->data + i1*(src0->nb[1]));
+        float * dA   = (float *) ((char *) src1->data + i1*(src1->nb[1]));
+        float * dB_x = (float *) ((char *) src2->data + i1*(src2->nb[1]));
+        ggml_vec_mul_f32(nc, dest, s, dA);
+        ggml_vec_add_f32(nc, dest, dest, dB_x);
+    }
+
+    // rest of batches, state comes from dest
+    for (int i2 = 1; i2 < n_b; i2++) {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            float * dest = (float *) ((char *)  dst->data + i1*( dst->nb[1]) +  i2   *( dst->nb[2]));
+            float * s    = (float *) ((char *)  dst->data + i1*( dst->nb[1]) + (i2-1)*( dst->nb[2]));
+            float * dA   = (float *) ((char *) src1->data + i1*(src1->nb[1]) +  i2   *(src1->nb[2]));
+            float * dB_x = (float *) ((char *) src2->data + i1*(src2->nb[1]) +  i2   *(src2->nb[2]));
+            ggml_vec_mul_f32(nc, dest, s, dA);
+            ggml_vec_add_f32(nc, dest, dest, dB_x);
+        }
+    }
+}
+
+static void ggml_compute_forward_ssm_scan(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * src2,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_win_part

 static void ggml_compute_forward_win_part_f32(
@@ -15814,6 +15922,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 bool masked = t != 0;
                 ggml_compute_forward_flash_attn_back(params, masked, tensor);
             } break;
+        case GGML_OP_SSM_SCAN:
+            {
+                ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+            } break;
         case GGML_OP_WIN_PART:
             {
                 ggml_compute_forward_win_part(params, tensor);
             } break;
@@ -16868,6 +16980,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // not supported
             } break;
+        case GGML_OP_SSM_SCAN:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_UNARY:
@@ -17570,6 +17686,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
            {
                n_tasks = n_threads;
            } break;
+        case GGML_OP_SSM_SCAN:
+            {
+                n_tasks = n_threads;
+            } break;
        case GGML_OP_WIN_PART:
        case GGML_OP_WIN_UNPART:
        case GGML_OP_GET_REL_POS:
diff --git a/ggml.h b/ggml.h
index efb62c598..0a40f8762 100644
--- a/ggml.h
+++ b/ggml.h
@@ -462,6 +462,7 @@ extern "C" {
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
         GGML_OP_FLASH_ATTN_BACK,
+        GGML_OP_SSM_SCAN,
         GGML_OP_WIN_PART,
         GGML_OP_WIN_UNPART,
         GGML_OP_GET_REL_POS,
@@ -1720,6 +1721,12 @@ extern "C" {
             struct ggml_tensor  * c0,
             struct ggml_tensor  * c1);

+    GGML_API struct ggml_tensor * ggml_ssm_scan(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * s,
+            struct ggml_tensor  * dA,
+            struct ggml_tensor  * dB_x);
+
     // partition into non-overlapping windows with padding if needed
     // example:
     // a:   768   64  64    1
diff --git a/llama.cpp b/llama.cpp
index f064969d2..bbf16e8f4 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -7935,6 +7935,13 @@ struct llm_build_context {
             ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv - 1, d_inner);
             ggml_tensor * ssm_state  = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);

+            // reset the states when starting a new sequence
+            // TODO: ensure kv_self clearing is handled
+            if (!batch.pos || batch.pos[0] == 0) {
+                conv_state = ggml_scale(ctx0, conv_state, 0);
+                ssm_state  = ggml_scale(ctx0, ssm_state, 0);
+            }
+
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
@@ -7991,36 +7998,79 @@ struct llm_build_context {
             // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tok} => {dt_rank + 2*d_state, n_tok}
             struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
             // split
-            struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, x_db->ne[1], x_db->nb[1], 0);
-            struct ggml_tensor * B  = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*dt_rank);
-            struct ggml_tensor * C  = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
+            struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tok, x_db->nb[1], 0);
+            struct ggml_tensor * B  = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+            struct ggml_tensor * C  = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));

             // {dt_rank, d_inner} * {dt_rank, n_tok} => {d_inner, n_tok}
             dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
             dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
             dt = ggml_soft_plus(ctx0, dt);

-            // FIXME: support batches with more than 1 token
-            // => {d_state, d_inner}
-            struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, ggml_transpose(ctx0, dt)));
+            struct ggml_tensor * dA;
+            struct ggml_tensor * dB;
+            if (n_tok == 1) {
+                // => {d_state, d_inner}
+                dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, ggml_transpose(ctx0, dt)));

-            // => {d_state, d_inner}
-            struct ggml_tensor * dB = ggml_out_prod(ctx0, B, dt);
+                // {d_state} * {d_inner} => {d_state, d_inner}
+                dB = ggml_out_prod(ctx0, B, dt);
+            } else {
+                // {d_state, d_inner} * {d_inner, n_tok} => {d_state, d_inner, n_tok} * {1, d_inner, n_tok}
+                // => {d_state, d_inner, n_tok}
+                // Trying to do the equivalent of
+                //  dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate)
+                struct ggml_tensor * A = model.layers[il].ssm_a;
+                dA = ggml_exp(ctx0,
+                    ggml_mul(ctx0,
+                        ggml_repeat(ctx0, A, ggml_new_tensor_3d(ctx0, A->type, d_state, d_inner, n_tok)),
+                        // {d_inner, n_tok} => {1, d_inner, n_tok}
+                        ggml_permute(ctx0, dt, 1, 2, 0, 3))
+                    );

-            // => {d_state, d_inner}
-            cur = ggml_mul(ctx0, dB, ggml_transpose(ctx0, x));
+                // {d_state, 1, n_tok} * {d_inner, 1, n_tok} => {d_state, d_inner, n_tok}
+                dB = ggml_out_prod(ctx0,
+                    // {d_state, n_tok} => {d_state, 1, n_tok}
+                    ggml_permute(ctx0, B, 0, 2, 1, 3),
+                    // {d_state, n_tok} => {d_state, 1, n_tok}
+                    ggml_permute(ctx0, dt, 0, 2, 1, 3));
+            }

-            ssm_state = ggml_add(ctx0, ggml_mul(ctx0, ssm_state, dA), cur);
+            // {d_state, d_inner, n_tok} * {1, d_inner, n_tok} => {d_state, d_inner, n_tok}
+            cur = ggml_mul(ctx0, dB, ggml_permute(ctx0, x, 1, 2, 0, 3));

-            ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_state, ggml_view_tensor(ctx0, kv_self.v_l[il])));
+            // The selective scan seems inherently sequential...
+            // To avoid making (n_layer * n_tok) graph nodes, let's use a custom operator.
+            // When n_tok == 1, it's equivalent to the following:
+            //  ssm_state = ggml_add(ctx0, ggml_mul(ctx0, ssm_state, dA), cur);
+            // When n_tok is bigger, it's the same thing, but iterated n_tok times,
+            // with the correct dA and cur for each token.
+            // The resulting states are layered on the ne[2] dimension.
+            // => {d_state, d_inner, n_tok}
+            ssm_state = ggml_ssm_scan(ctx0, ssm_state, dA, cur);

-            // row-wise dot product ("dn,n->d")
-            // {d_state, d_inner} * {d_state} => {d_inner, 1}
-            struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, C);
-            y = ggml_add(ctx0, y, ggml_mul(ctx0, model.layers[il].ssm_d, x));
+            // only store last state
+            ggml_build_forward_expand(gf,
+                ggml_cpy(ctx0,
+                    ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tok-1)*ssm_state->nb[2]),
+                    ggml_view_tensor(ctx0, kv_self.v_l[il])));
+
+            struct ggml_tensor * y;
+            if (n_tok == 1) {
+                // row-wise dot product ("dn,n->d")
+                // {d_state, d_inner} * {d_state, 1} => {d_inner, 1}
+                y = ggml_mul_mat(ctx0, ssm_state, C);
+            } else {
+                // {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok}
+                y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
+                // => {d_inner, n_tok}
+                y = ggml_permute(ctx0, y, 0, 2, 1, 3);
+            }
+            // {d_inner, n_tok} * {d_inner} => {d_inner, n_tok}
+            y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
             y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));

-            // {d_inner, n_embd} * {d_inner, 1} => {n_embd, 1}
+            // {d_inner, n_embd} * {d_inner, n_tok} => {n_embd, n_tok}
             cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
         }
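
Note on the semantics of the new operator (illustration, not part of the diff): ggml_ssm_scan(s, dA, dB_x) applies the element-wise recurrence s[t] = dA[t] * s[t-1] + dB_x[t] over the token dimension (ne[2]), starting from the 2D state s, and returns one {d_state, d_inner} state per token. The sketch below is a minimal, single-threaded C reference of that recurrence, assuming contiguous F32 data; the helper name ref_ssm_scan and the toy values are made up for this example and do not exist in ggml, whose kernel additionally handles strided tensors and splits rows across threads.

#include <stdio.h>

// Reference recurrence of the scan (illustration only):
//   out[t] = prev * dA[t] + dB_x[t], element-wise over a {d_state, d_inner} plane,
// where prev is the initial state s0 for t == 0 and out[t-1] afterwards.
static void ref_ssm_scan(int d_state, int d_inner, int n_tok,
                         const float * s0,    // {d_state, d_inner}         initial state
                         const float * dA,    // {d_state, d_inner, n_tok}
                         const float * dB_x,  // {d_state, d_inner, n_tok}
                         float * out) {       // {d_state, d_inner, n_tok}  one state per token
    const int plane = d_state * d_inner;
    for (int t = 0; t < n_tok; t++) {
        const float * prev = (t == 0) ? s0 : out + (t - 1) * plane;
        for (int i = 0; i < plane; i++) {
            out[t * plane + i] = prev[i] * dA[t * plane + i] + dB_x[t * plane + i];
        }
    }
}

int main(void) {
    // toy sizes: d_state = 2, d_inner = 1, n_tok = 3
    const float s0[]   = { 1.0f, 2.0f };
    const float dA[]   = { 0.5f, 0.5f,   0.5f, 0.5f,   0.5f, 0.5f };
    const float dB_x[] = { 1.0f, 0.0f,   1.0f, 0.0f,   1.0f, 0.0f };
    float out[6];
    ref_ssm_scan(2, 1, 3, s0, dA, dB_x, out);
    for (int t = 0; t < 3; t++) {
        printf("t=%d: %.3f %.3f\n", t, out[t*2 + 0], out[t*2 + 1]);
    }
    return 0;
}

Only the last {d_state, d_inner} plane of the result is kept across graph evaluations: the llama.cpp change views the (n_tok-1)-th slice of ssm_state and copies it back into kv_self.v_l[il] ("only store last state"), while the full stack of per-token states is consumed by the ggml_mul_mat with C to produce y.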