mamba : handle batches of more than 1 token

This means running Mamba no longer crashes when using the default settings!
Prompt processing is probably also slightly faster.
Both batched and non-batched processing yield the same output.

Previously, the state was not cleared when starting a sequence.
Next step is to make the KV cache API work as expected for Mamba models.

* ggml: add ggml_ssm_scan to help with parallel selective scan

If the selective scan were implemented without a custom operator,
there would be far too many nodes in the graph. For example,
for Mamba-130M, with a batch size of 512 (the default),
a naive selective scan could add at least 24*512=12288 nodes,
which is more than LLAMA_MAX_NODES (8192),
and that's only for the smallest Mamba model.
So it's much cleaner with a custom operator.
Not sure about the name, though.
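
For a concrete sense of the node-count argument, here is a rough sketch of graph construction with and without the fused operator (illustration only, not code from this commit; ctx0, n_tok, state, dA_t, dB_x_t and the surrounding loop are hypothetical names):

    // naive scan: at least 2 graph nodes (mul + add) per token, per layer
    for (int t = 0; t < n_tok; t++) {
        // state = state*dA_t[t] + dB_x_t[t]
        state = ggml_add(ctx0, ggml_mul(ctx0, state, dA_t[t]), dB_x_t[t]);
    }

    // fused scan: a single graph node per layer, regardless of n_tok
    state = ggml_ssm_scan(ctx0, state, dA, dB_x);
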
Author: Francis Couture-Harpin
Date:   2024-01-31 20:45:04 -05:00
parent 81b57bb375
commit ffc116f5ec
3 changed files with 196 additions and 19 deletions

ggml.c (124 lines changed)

@@ -1831,6 +1831,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "FLASH_ATTN",
     "FLASH_FF",
     "FLASH_ATTN_BACK",
+    "SSM_SCAN",
     "WIN_PART",
     "WIN_UNPART",
     "GET_REL_POS",
@@ -1853,7 +1854,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1919,6 +1920,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_attn(x)",
     "flash_ff(x)",
     "flash_attn_back(x)",
+    "ssm_scan(x)",
     "win_part(x)",
     "win_unpart(x)",
     "get_rel_pos(x)",
@@ -1941,7 +1943,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -6149,6 +6151,40 @@ struct ggml_tensor * ggml_flash_attn_back(
     return result;
 }
 
+// ggml_ssm_scan
+
+struct ggml_tensor * ggml_ssm_scan(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * s,
+        struct ggml_tensor  * dA,
+        struct ggml_tensor  * dB_x) {
+    GGML_ASSERT(ggml_are_same_shape(dA, dB_x));
+    GGML_ASSERT(   s->nb[0] == ggml_type_size(   s->type));
+    GGML_ASSERT(  dA->nb[0] == ggml_type_size(  dA->type));
+    GGML_ASSERT(dB_x->nb[0] == ggml_type_size(dB_x->type));
+    GGML_ASSERT(s->ne[0] == dA->ne[0]);
+    GGML_ASSERT(s->ne[1] == dA->ne[1]);
+    GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1); // the ssm_state should be 2D
+
+    bool is_node = false;
+
+    if (s->grad || dA->grad || dB_x->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, dA->ne[0], dA->ne[1], dA->ne[2]);
+
+    result->op     = GGML_OP_SSM_SCAN;
+    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = s;
+    result->src[1] = dA;
+    result->src[2] = dB_x;
+
+    return result;
+}
+
 // ggml_win_part
 
 struct ggml_tensor * ggml_win_part(
@@ -14755,6 +14791,78 @@ static void ggml_compute_forward_flash_attn_back(
     }
 }
 
+// ggml_compute_forward_ssm_scan
+
+static void ggml_compute_forward_ssm_scan_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * src2,
+              struct ggml_tensor * dst) {
+    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nc  = src1->ne[0];
+    const int64_t n_b = src1->ne[2]; // number of batches
+    const int64_t nr0 = ggml_nrows(src0);
+
+    GGML_ASSERT(nc*n_b*nr0 == ggml_nelements(src1));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr0 + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr0);
+
+    // first batch
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dest = (float *) ((char *)  dst->data + i1*( dst->nb[1]));
+        float * s    = (float *) ((char *) src0->data + i1*(src0->nb[1]));
+        float * dA   = (float *) ((char *) src1->data + i1*(src1->nb[1]));
+        float * dB_x = (float *) ((char *) src2->data + i1*(src2->nb[1]));
+        ggml_vec_mul_f32(nc, dest, s, dA);
+        ggml_vec_add_f32(nc, dest, dest, dB_x);
+    }
+
+    // rest of batches, state comes from dest
+    for (int i2 = 1; i2 < n_b; i2++) {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            float * dest = (float *) ((char *)  dst->data + i1*( dst->nb[1]) +  i2   *( dst->nb[2]));
+            float * s    = (float *) ((char *)  dst->data + i1*( dst->nb[1]) + (i2-1)*( dst->nb[2]));
+            float * dA   = (float *) ((char *) src1->data + i1*(src1->nb[1]) +  i2   *(src1->nb[2]));
+            float * dB_x = (float *) ((char *) src2->data + i1*(src2->nb[1]) +  i2   *(src2->nb[2]));
+            ggml_vec_mul_f32(nc, dest, s, dA);
+            ggml_vec_add_f32(nc, dest, dest, dB_x);
+        }
+    }
+}
+
+static void ggml_compute_forward_ssm_scan(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * src2,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_win_part
 
 static void ggml_compute_forward_win_part_f32(
@@ -15814,6 +15922,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
                 bool masked = t != 0;
                 ggml_compute_forward_flash_attn_back(params, masked, tensor);
             } break;
+        case GGML_OP_SSM_SCAN:
+            {
+                ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+            } break;
         case GGML_OP_WIN_PART:
             {
                 ggml_compute_forward_win_part(params, tensor);
@@ -16868,6 +16980,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // not supported
             } break;
+        case GGML_OP_SSM_SCAN:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_UNARY:
@@ -17570,6 +17686,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = n_threads;
             } break;
+        case GGML_OP_SSM_SCAN:
+            {
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
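
Stripped of the threading and stride bookkeeping, the new kernel just computes the recurrence dst[t] = dst[t-1]*dA[t] + dB_x[t] independently for each row. A plain-C reference for a single row (a sketch for clarity, not code from the commit; ssm_scan_row_ref is a made-up name):

    // one row of nc state values, scanned over n_b tokens;
    // s is the incoming state, dst receives one nc-sized slice per token
    void ssm_scan_row_ref(int nc, int n_b, const float * s,
                          const float * dA, const float * dB_x, float * dst) {
        // first token: the previous state comes from s
        for (int i = 0; i < nc; i++) {
            dst[i] = s[i]*dA[i] + dB_x[i];
        }
        // remaining tokens: the previous state is the previous dst slice
        for (int t = 1; t < n_b; t++) {
            for (int i = 0; i < nc; i++) {
                dst[t*nc + i] = dst[(t-1)*nc + i]*dA[t*nc + i] + dB_x[t*nc + i];
            }
        }
    }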

ggml.h (7 lines changed)

@@ -462,6 +462,7 @@ extern "C" {
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
         GGML_OP_FLASH_ATTN_BACK,
+        GGML_OP_SSM_SCAN,
         GGML_OP_WIN_PART,
         GGML_OP_WIN_UNPART,
         GGML_OP_GET_REL_POS,
@@ -1720,6 +1721,12 @@
             struct ggml_tensor  * c0,
             struct ggml_tensor  * c1);
 
+    GGML_API struct ggml_tensor * ggml_ssm_scan(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * s,
+            struct ggml_tensor  * dA,
+            struct ggml_tensor  * dB_x);
+
     // partition into non-overlapping windows with padding if needed
     // example:
     // a:   768   64  64    1
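
A minimal usage sketch of the new API (an illustration under assumptions, not part of the commit: ctx is an initialized ggml_context and d_state, d_inner, n_tok are the dimensions used by the Mamba block below):

    // previous state, one row of d_state values per inner channel
    struct ggml_tensor * s    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, d_inner);
    // per-token decay and input terms, one {d_state, d_inner} slice per token
    struct ggml_tensor * dA   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_tok);
    struct ggml_tensor * dB_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_tok);
    // out->ne = {d_state, d_inner, n_tok}; slice i2 holds the state after token i2
    struct ggml_tensor * out  = ggml_ssm_scan(ctx, s, dA, dB_x);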

llama.cpp (84 lines changed)

@@ -7935,6 +7935,13 @@ struct llm_build_context {
             ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv - 1, d_inner);
             ggml_tensor * ssm_state  = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);
 
+            // reset the states when starting a new sequence
+            // TODO: ensure kv_self clearing is handled
+            if (!batch.pos || batch.pos[0] == 0) {
+                conv_state = ggml_scale(ctx0, conv_state, 0);
+                ssm_state  = ggml_scale(ctx0, ssm_state,  0);
+            }
+
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
@@ -7991,36 +7998,79 @@ struct llm_build_context {
             // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tok} => {dt_rank + 2*d_state, n_tok}
             struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
             // split
-            struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, x_db->ne[1], x_db->nb[1], 0);
-            struct ggml_tensor * B  = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*dt_rank);
-            struct ggml_tensor * C  = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
+            struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tok, x_db->nb[1], 0);
+            struct ggml_tensor * B  = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+            struct ggml_tensor * C  = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
 
             // {dt_rank, d_inner} * {dt_rank, n_tok} => {d_inner, n_tok}
             dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
             dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
             dt = ggml_soft_plus(ctx0, dt);
 
-            // FIXME: support batches with more than 1 token
-            // => {d_state, d_inner}
-            struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, ggml_transpose(ctx0, dt)));
-            // => {d_state, d_inner}
-            struct ggml_tensor * dB = ggml_out_prod(ctx0, B, dt);
-            // => {d_state, d_inner}
-            cur = ggml_mul(ctx0, dB, ggml_transpose(ctx0, x));
-
-            ssm_state = ggml_add(ctx0, ggml_mul(ctx0, ssm_state, dA), cur);
-
-            ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_state, ggml_view_tensor(ctx0, kv_self.v_l[il])));
-
-            // row-wise dot product ("dn,n->d")
-            // {d_state, d_inner} * {d_state} => {d_inner, 1}
-            struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, C);
-            y = ggml_add(ctx0, y, ggml_mul(ctx0, model.layers[il].ssm_d, x));
+            struct ggml_tensor * dA;
+            struct ggml_tensor * dB;
+            if (n_tok == 1) {
+                // => {d_state, d_inner}
+                dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, ggml_transpose(ctx0, dt)));
+                // {d_state} * {d_inner} => {d_state, d_inner}
+                dB = ggml_out_prod(ctx0, B, dt);
+            } else {
+                // {d_state, d_inner} * {d_inner, n_tok} => {d_state, d_inner, n_tok} * {1, d_inner, n_tok}
+                // => {d_state, d_inner, n_tok}
+                // Trying to do the equivalent of
+                //  dA = torch.exp(rearrange(dt, "b d -> b d 1") * A)  # (batch, dim, dstate)
+                struct ggml_tensor * A = model.layers[il].ssm_a;
+                dA = ggml_exp(ctx0,
+                    ggml_mul(ctx0,
+                        ggml_repeat(ctx0, A, ggml_new_tensor_3d(ctx0, A->type, d_state, d_inner, n_tok)),
+                        // {d_inner, n_tok} => {1, d_inner, n_tok}
+                        ggml_permute(ctx0, dt, 1, 2, 0, 3))
+                    );
+                // {d_state, 1, n_tok} * {d_inner, 1, n_tok} => {d_state, d_inner, n_tok}
+                dB = ggml_out_prod(ctx0,
+                    // {d_state, n_tok} => {d_state, 1, n_tok}
+                    ggml_permute(ctx0, B, 0, 2, 1, 3),
+                    // {d_inner, n_tok} => {d_inner, 1, n_tok}
+                    ggml_permute(ctx0, dt, 0, 2, 1, 3));
+            }
+
+            // {d_state, d_inner, n_tok} * {1, d_inner, n_tok} => {d_state, d_inner, n_tok}
+            cur = ggml_mul(ctx0, dB, ggml_permute(ctx0, x, 1, 2, 0, 3));
+
+            // The selective scan seems inherently sequential...
+            // To avoid making (n_layer * n_tok) graph nodes, let's use a custom operator.
+            // When n_tok == 1, it's equivalent to the following:
+            //   ssm_state = ggml_add(ctx0, ggml_mul(ctx0, ssm_state, dA), cur);
+            // When n_tok is bigger, it's the same thing, but iterated n_tok times,
+            // with the correct dA and cur for each token.
+            // The resulting states are layered on the ne[2] dimension.
+            // => {d_state, d_inner, n_tok}
+            ssm_state = ggml_ssm_scan(ctx0, ssm_state, dA, cur);
+
+            // only store last state
+            ggml_build_forward_expand(gf,
+                ggml_cpy(ctx0,
+                    ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tok-1)*ssm_state->nb[2]),
+                    ggml_view_tensor(ctx0, kv_self.v_l[il])));
+
+            struct ggml_tensor * y;
+            if (n_tok == 1) {
+                // row-wise dot product ("dn,n->d")
+                // {d_state, d_inner} * {d_state, 1} => {d_inner, 1}
+                y = ggml_mul_mat(ctx0, ssm_state, C);
+            } else {
+                // {d_state, d_inner, n_tok} * {d_state, 1, n_tok} => {d_inner, 1, n_tok}
+                y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
+                // => {d_inner, n_tok}
+                y = ggml_permute(ctx0, y, 0, 2, 1, 3);
+            }
+
+            // {d_inner, n_tok} * {d_inner} => {d_inner, n_tok}
+            y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
             y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
 
-            // {d_inner, n_embd} * {d_inner, 1} => {n_embd, 1}
+            // {d_inner, n_embd} * {d_inner, n_tok} => {n_embd, n_tok}
             cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
         }
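
To see what the batched path above amounts to for a single token t, here is a scalar reference written as plain loops (a sketch only; A, B, C, D, dt, x, y and state name the tensors used above, with dt taken after the softplus, and the [token][channel] indexing chosen for readability rather than ggml's actual memory layout):

    for (int d = 0; d < d_inner; d++) {
        float acc = 0.0f;
        for (int n = 0; n < d_state; n++) {
            const float dA_tn   = expf(A[d][n] * dt[t][d]);        // discretized decay
            const float dB_x_tn = B[t][n] * dt[t][d] * x[t][d];    // discretized input term
            state[d][n] = state[d][n]*dA_tn + dB_x_tn;             // the ggml_ssm_scan step
            acc += state[d][n] * C[t][n];                          // readout, "dn,n->d"
        }
        y[t][d] = acc + D[d]*x[t][d];                              // D is ssm_d (skip connection)
    }

Running this token by token, in order, should match what ggml_ssm_scan plus the surrounding ggml_mul_mat and ggml_add calls compute for the whole batch, which is why batched and non-batched processing yield the same output.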