mamba : handle batches of more than 1 token
This means running Mamba no longer crashes when using the default settings! Prompt processing is probably also slightly faster, and batched and non-batched processing yield the same output.

Previously, the state was not cleared when starting a new sequence. The next step is to make the KV cache API work as expected for Mamba models.

* ggml : add ggml_ssm_scan to help with parallel selective scan

  If the selective scan were implemented without a custom operator, there would be far too many nodes in the graph. For example, for Mamba-130M with a batch size of 512 (the default), a naive selective scan would add at least 24*512 = 12288 nodes, which is more than LLAMA_MAX_NODES (8192), and that's only for the smallest Mamba model. So it's much cleaner with a custom operator. Not sure about the name, though.
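For context, the forward pass of the new operator (see ggml_compute_forward_ssm_scan_f32 in the diff below) is just the per-token update s = dA*s + dB_x applied elementwise, keeping the state after every token. A minimal plain-C sketch of that loop, with hypothetical flat arrays in place of ggml tensors:

// Plain-C sketch of the recurrence fused by ggml_ssm_scan (illustration only;
// the function name and flat layout are assumptions, not the ggml implementation).
// For each token t:  s = dA[t] * s + dB_x[t]   (elementwise over d_state*d_inner entries)
#include <stddef.h>

static void ssm_scan_reference(
        float       * s,     // d_state*d_inner, initial state, updated in place
        const float * dA,    // n_tok x (d_state*d_inner)
        const float * dB_x,  // n_tok x (d_state*d_inner)
        float       * out,   // n_tok x (d_state*d_inner), state after each token
        size_t n, size_t n_tok) {          // n = d_state*d_inner
    for (size_t t = 0; t < n_tok; t++) {
        for (size_t i = 0; i < n; i++) {
            s[i] = dA[t*n + i] * s[i] + dB_x[t*n + i];
            out[t*n + i] = s[i];
        }
    }
}

In ggml this whole loop becomes a single graph node regardless of n_tok, which is what keeps the graph under LLAMA_MAX_NODES.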
parent 81b57bb375
commit ffc116f5ec
3 changed files with 196 additions and 19 deletions
ggml.c | 124
@@ -1831,6 +1831,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "FLASH_ATTN",
     "FLASH_FF",
     "FLASH_ATTN_BACK",
+    "SSM_SCAN",
     "WIN_PART",
     "WIN_UNPART",
     "GET_REL_POS",
@@ -1853,7 +1854,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1919,6 +1920,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_attn(x)",
     "flash_ff(x)",
     "flash_attn_back(x)",
+    "ssm_scan(x)",
     "win_part(x)",
     "win_unpart(x)",
     "get_rel_pos(x)",
@@ -1941,7 +1943,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -6149,6 +6151,40 @@ struct ggml_tensor * ggml_flash_attn_back(
     return result;
 }
 
+// ggml_ssm_scan
+
+struct ggml_tensor * ggml_ssm_scan(
+        struct ggml_context * ctx,
+        struct ggml_tensor * s,
+        struct ggml_tensor * dA,
+        struct ggml_tensor * dB_x) {
+    GGML_ASSERT(ggml_are_same_shape(dA, dB_x));
+
+    GGML_ASSERT(   s->nb[0] == ggml_type_size(   s->type));
+    GGML_ASSERT(  dA->nb[0] == ggml_type_size(  dA->type));
+    GGML_ASSERT(dB_x->nb[0] == ggml_type_size(dB_x->type));
+
+    GGML_ASSERT(s->ne[0] == dA->ne[0]);
+    GGML_ASSERT(s->ne[1] == dA->ne[1]);
+    GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1); // the ssm_state should be 2D
+
+    bool is_node = false;
+
+    if (s->grad || dA->grad || dB_x->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, dA->ne[0], dA->ne[1], dA->ne[2]);
+
+    result->op = GGML_OP_SSM_SCAN;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = s;
+    result->src[1] = dA;
+    result->src[2] = dB_x;
+
+    return result;
+}
+
 // ggml_win_part
 
 struct ggml_tensor * ggml_win_part(
@@ -14755,6 +14791,78 @@ static void ggml_compute_forward_flash_attn_back(
     }
 }
 
+// ggml_compute_forward_ssm_scan
+
+static void ggml_compute_forward_ssm_scan_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * src2,
+        struct ggml_tensor * dst) {
+    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nc  = src1->ne[0];
+    const int64_t n_b = src1->ne[2]; // number of batches
+    const int64_t nr0 = ggml_nrows(src0);
+
+    GGML_ASSERT(nc*n_b*nr0 == ggml_nelements(src1));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr0 + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr0);
+
+    // first batch
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dest = (float *) ((char *)  dst->data + i1*( dst->nb[1]));
+        float * s    = (float *) ((char *) src0->data + i1*(src0->nb[1]));
+        float * dA   = (float *) ((char *) src1->data + i1*(src1->nb[1]));
+        float * dB_x = (float *) ((char *) src2->data + i1*(src2->nb[1]));
+        ggml_vec_mul_f32(nc, dest, s, dA);
+        ggml_vec_add_f32(nc, dest, dest, dB_x);
+    }
+
+    // rest of batches, state comes from dest
+    for (int i2 = 1; i2 < n_b; i2++) {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            float * dest = (float *) ((char *)  dst->data + i1*( dst->nb[1]) +  i2   *( dst->nb[2]));
+            float * s    = (float *) ((char *)  dst->data + i1*( dst->nb[1]) + (i2-1)*( dst->nb[2]));
+            float * dA   = (float *) ((char *) src1->data + i1*(src1->nb[1]) +  i2   *(src1->nb[2]));
+            float * dB_x = (float *) ((char *) src2->data + i1*(src2->nb[1]) +  i2   *(src2->nb[2]));
+            ggml_vec_mul_f32(nc, dest, s, dA);
+            ggml_vec_add_f32(nc, dest, dest, dB_x);
+        }
+    }
+}
+
+static void ggml_compute_forward_ssm_scan(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * src2,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_win_part
 
 static void ggml_compute_forward_win_part_f32(
@@ -15814,6 +15922,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 bool masked = t != 0;
                 ggml_compute_forward_flash_attn_back(params, masked, tensor);
             } break;
+        case GGML_OP_SSM_SCAN:
+            {
+                ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+            } break;
         case GGML_OP_WIN_PART:
             {
                 ggml_compute_forward_win_part(params, tensor);
@@ -16868,6 +16980,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // not supported
            } break;
+        case GGML_OP_SSM_SCAN:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_UNARY:
@@ -17570,6 +17686,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
            {
                n_tasks = n_threads;
            } break;
+        case GGML_OP_SSM_SCAN:
+            {
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
        case GGML_OP_GET_REL_POS:
ggml.h | 7

@@ -462,6 +462,7 @@ extern "C" {
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
         GGML_OP_FLASH_ATTN_BACK,
+        GGML_OP_SSM_SCAN,
         GGML_OP_WIN_PART,
         GGML_OP_WIN_UNPART,
         GGML_OP_GET_REL_POS,
@@ -1720,6 +1721,12 @@ extern "C" {
             struct ggml_tensor * c0,
             struct ggml_tensor * c1);
 
+    GGML_API struct ggml_tensor * ggml_ssm_scan(
+            struct ggml_context * ctx,
+            struct ggml_tensor * s,
+            struct ggml_tensor * dA,
+            struct ggml_tensor * dB_x);
+
     // partition into non-overlapping windows with padding if needed
     // example:
     // a: 768 64 64 1
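A hedged usage sketch of the new API (the wrapper and its name are hypothetical; the shape comments follow the asserts in ggml_ssm_scan, and taking a view of the last state mirrors what llama.cpp does below):

#include <stdint.h>
#include "ggml.h"

// Hypothetical helper (illustration only); ctx and the input tensors are assumed to exist.
//   s    : {d_state, d_inner}         previous ssm_state
//   dA   : {d_state, d_inner, n_tok}  per-token decay factors
//   dB_x : {d_state, d_inner, n_tok}  per-token inputs (dB * x)
static struct ggml_tensor * scan_and_keep_last_state(
        struct ggml_context * ctx,
        struct ggml_tensor * s,
        struct ggml_tensor * dA,
        struct ggml_tensor * dB_x) {
    // states after each token, layered on ne[2] => {d_state, d_inner, n_tok}
    struct ggml_tensor * states = ggml_ssm_scan(ctx, s, dA, dB_x);
    // 2D view of the state after the last token
    const int64_t n_tok = dA->ne[2];
    return ggml_view_2d(ctx, states, states->ne[0], states->ne[1],
                        states->nb[1], (n_tok - 1)*states->nb[2]);
}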
llama.cpp | 84

@@ -7935,6 +7935,13 @@ struct llm_build_context {
             ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv - 1, d_inner);
             ggml_tensor * ssm_state = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);
 
+            // reset the states when starting a new sequence
+            // TODO: ensure kv_self clearing is handled
+            if (!batch.pos || batch.pos[0] == 0) {
+                conv_state = ggml_scale(ctx0, conv_state, 0);
+                ssm_state = ggml_scale(ctx0, ssm_state, 0);
+            }
+
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
@@ -7991,36 +7998,79 @@ struct llm_build_context {
             // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tok} => {dt_rank + 2*d_state, n_tok}
             struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
             // split
-            struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, x_db->ne[1], x_db->nb[1], 0);
-            struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*dt_rank);
-            struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
+            struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tok, x_db->nb[1], 0);
+            struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+            struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
 
             // {dt_rank, d_inner} * {dt_rank, n_tok} => {d_inner, n_tok}
             dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
             dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
             dt = ggml_soft_plus(ctx0, dt);
 
-            // FIXME: support batches with more than 1 token
-            // => {d_state, d_inner}
-            struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, ggml_transpose(ctx0, dt)));
+            struct ggml_tensor * dA;
+            struct ggml_tensor * dB;
+            if (n_tok == 1) {
+                // => {d_state, d_inner}
+                dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, ggml_transpose(ctx0, dt)));
 
-            // => {d_state, d_inner}
-            struct ggml_tensor * dB = ggml_out_prod(ctx0, B, dt);
+                // {d_state} * {d_inner} => {d_state, d_inner}
+                dB = ggml_out_prod(ctx0, B, dt);
+            } else {
+                // {d_state, d_inner} * {d_inner, n_tok} => {d_state, d_inner, n_tok} * {1, d_inner, n_tok}
+                // => {d_state, d_inner, n_tok}
+                // Trying to do the equivalent of
+                // dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate)
+                struct ggml_tensor * A = model.layers[il].ssm_a;
+                dA = ggml_exp(ctx0,
+                    ggml_mul(ctx0,
+                        ggml_repeat(ctx0, A, ggml_new_tensor_3d(ctx0, A->type, d_state, d_inner, n_tok)),
+                        // {d_inner, n_tok} => {1, d_inner, n_tok}
+                        ggml_permute(ctx0, dt, 1, 2, 0, 3))
+                    );
 
-            // => {d_state, d_inner}
-            cur = ggml_mul(ctx0, dB, ggml_transpose(ctx0, x));
+                // {d_state, 1, n_tok} * {d_inner, 1, n_tok} => {d_state, d_inner, n_tok}
+                dB = ggml_out_prod(ctx0,
+                    // {d_state, n_tok} => {d_state, 1, n_tok}
+                    ggml_permute(ctx0, B, 0, 2, 1, 3),
+                    // {d_state, n_tok} => {d_state, 1, n_tok}
+                    ggml_permute(ctx0, dt, 0, 2, 1, 3));
+            }
 
-            ssm_state = ggml_add(ctx0, ggml_mul(ctx0, ssm_state, dA), cur);
+            // {d_state, d_inner, n_tok} * {1, d_inner, n_tok} => {d_state, d_inner, n_tok}
+            cur = ggml_mul(ctx0, dB, ggml_permute(ctx0, x, 1, 2, 0, 3));
 
-            ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_state, ggml_view_tensor(ctx0, kv_self.v_l[il])));
+            // The selective scan seems inherently sequential...
+            // To avoid making (n_layer * n_tok) graph nodes, let's use a custom operator.
+            // When n_tok == 1, it's equivalent to the following:
+            // ssm_state = ggml_add(ctx0, ggml_mul(ctx0, ssm_state, dA), cur);
+            // When n_tok is bigger, it's the same thing, but iterated n_tok times,
+            // with the correct dA and cur for each token.
+            // The resulting states are layered on the ne[2] dimension.
+            // => {d_state, d_inner, n_tok}
+            ssm_state = ggml_ssm_scan(ctx0, ssm_state, dA, cur);
 
-            // row-wise dot product ("dn,n->d")
-            // {d_state, d_inner} * {d_state} => {d_inner, 1}
-            struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, C);
-            y = ggml_add(ctx0, y, ggml_mul(ctx0, model.layers[il].ssm_d, x));
+            // only store last state
+            ggml_build_forward_expand(gf,
+                ggml_cpy(ctx0,
+                    ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tok-1)*ssm_state->nb[2]),
+                    ggml_view_tensor(ctx0, kv_self.v_l[il])));
+
+            struct ggml_tensor * y;
+            if (n_tok == 1) {
+                // row-wise dot product ("dn,n->d")
+                // {d_state, d_inner} * {d_state, 1} => {d_inner, 1}
+                y = ggml_mul_mat(ctx0, ssm_state, C);
+            } else {
+                // {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok}
+                y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
+                // => {d_inner, n_tok}
+                y = ggml_permute(ctx0, y, 0, 2, 1, 3);
+            }
+            // {d_inner, n_tok} * {d_inner} => {d_inner, n_tok}
+            y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
             y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
 
-            // {d_inner, n_embd} * {d_inner, 1} => {n_embd, 1}
+            // {d_inner, n_embd} * {d_inner, n_tok} => {n_embd, n_tok}
             cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
         }
 
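For the batched (n_tok > 1) branch above, the repeat/permute/out_prod combination reduces to simple per-element formulas. A hypothetical plain-C reference (flat row-major arrays, names assumed, illustration only) that could be used to sanity-check the graph:

#include <math.h>

// Hypothetical reference for the batched dA / dB*x construction above
// (illustration only, not part of the commit):
//   dA[t][d][n]   = expf(dt[t][d] * A[d][n])
//   dB_x[t][d][n] = dt[t][d] * B[t][n] * x[t][d]
// with d = inner channel (d_inner), n = state index (d_state), t = token.
static void build_dA_dBx_reference(
        const float * A,    // d_inner x d_state : A[d*d_state + n]
        const float * dt,   // n_tok x d_inner   : dt[t*d_inner + d], after softplus
        const float * B,    // n_tok x d_state   : B[t*d_state + n]
        const float * x,    // n_tok x d_inner   : x[t*d_inner + d]
        float * dA,         // n_tok x d_inner x d_state
        float * dB_x,       // n_tok x d_inner x d_state
        int d_state, int d_inner, int n_tok) {
    for (int t = 0; t < n_tok; t++) {
        for (int d = 0; d < d_inner; d++) {
            for (int n = 0; n < d_state; n++) {
                const int i = (t*d_inner + d)*d_state + n;
                dA[i]   = expf(dt[t*d_inner + d] * A[d*d_state + n]);
                dB_x[i] = dt[t*d_inner + d] * B[t*d_state + n] * x[t*d_inner + d];
            }
        }
    }
}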