mamba : fuse more steps of the SSM scan in the ggml_ssm_scan operator
This increases performance on CPU by around 30% for prompt processing, and by around 20% for text generation. However, it also makes the ggml_exp and ggml_soft_plus operators unused. Whether or not they should be kept will be decided later.
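Concretely, the softplus of dt, the exponential forming dA, the outer product with B, and the multiplication by x are now all computed inside the scan loop instead of as separate graph operations. A minimal standalone sketch of the per-token update the fused kernel performs (illustrative names, not the ggml API):

    #include <math.h>

    // one fused update for a single inner dimension and token:
    // state = state * exp(softplus(dt) * A) + B * (x * softplus(dt))
    static void ssm_scan_step(int d_state, float * state, const float * A,
                              const float * B, float x, float dt) {
        const float dt_soft_plus = log1pf(expf(dt)); // softplus, fused into the scan
        const float x_dt = x * dt_soft_plus;
        for (int i0 = 0; i0 < d_state; ++i0) {
            state[i0] = state[i0]*expf(dt_soft_plus*A[i0]) + B[i0]*x_dt;
        }
    }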
parent 5816ae687e
commit a3f4a1c7dc
3 changed files with 88 additions and 82 deletions
ggml.c (115 changed lines)

@@ -6156,31 +6156,45 @@ struct ggml_tensor * ggml_flash_attn_back(
 struct ggml_tensor * ggml_ssm_scan(
         struct ggml_context * ctx,
         struct ggml_tensor  * s,
-        struct ggml_tensor  * dA,
-        struct ggml_tensor  * dB_x) {
-    GGML_ASSERT(ggml_are_same_shape(dA, dB_x));
-    GGML_ASSERT(   s->nb[0] == ggml_type_size(   s->type));
-    GGML_ASSERT(  dA->nb[0] == ggml_type_size(  dA->type));
-    GGML_ASSERT(dB_x->nb[0] == ggml_type_size(dB_x->type));
-    GGML_ASSERT(s->ne[0] == dA->ne[0]);
-    GGML_ASSERT(s->ne[1] == dA->ne[1]);
+        struct ggml_tensor  * x,
+        struct ggml_tensor  * dt,
+        struct ggml_tensor  * A,
+        struct ggml_tensor  * B) {
+    GGML_ASSERT(ggml_is_contiguous(s));
+    GGML_ASSERT(ggml_is_contiguous(x));
+    GGML_ASSERT(ggml_is_contiguous(dt));
+    GGML_ASSERT(ggml_is_contiguous(A));
+    GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
+    GGML_ASSERT(ggml_are_same_shape(x, dt));
     GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1); // the ssm_state should be 2D
+
+    {
+        const int64_t d_state = s->ne[0];
+        const int64_t d_inner = s->ne[1];
+        const int64_t n_tok   = x->ne[1];
+
+        GGML_ASSERT(x->ne[0] == d_inner);
+        GGML_ASSERT(A->ne[0] == d_state);
+        GGML_ASSERT(A->ne[1] == d_inner);
+        GGML_ASSERT(B->ne[0] == d_state);
+        GGML_ASSERT(B->ne[1] == n_tok);
+    }
 
     bool is_node = false;
 
-    if (s->grad || dA->grad || dB_x->grad) {
+    if (s->grad || x->grad || dt->grad || A->grad || B->grad) {
         is_node = true;
     }
 
-    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, dA->ne[0], dA->ne[1], dA->ne[2]);
+    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, s->ne[0], s->ne[1], x->ne[1]);
 
     result->op   = GGML_OP_SSM_SCAN;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = s;
-    result->src[1] = dA;
-    result->src[2] = dB_x;
+    result->src[1] = x;
+    result->src[2] = dt;
+    result->src[3] = A;
+    result->src[4] = B;
 
     return result;
 }
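A hypothetical call site, assuming F32 tensors with the shapes asserted above; the variable names and the dimensions d_state, d_inner, and n_tok are illustrative, not taken from the commit:

    // s: {d_state, d_inner}, x/dt: {d_inner, n_tok},
    // A: {d_state, d_inner}, B: {d_state, n_tok}
    struct ggml_tensor * s  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, d_inner);
    struct ggml_tensor * x  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_inner, n_tok);
    struct ggml_tensor * dt = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_inner, n_tok);
    struct ggml_tensor * A  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, d_inner);
    struct ggml_tensor * B  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, n_tok);

    // one updated state per token, stacked along ne[2]: {d_state, d_inner, n_tok}
    struct ggml_tensor * states = ggml_ssm_scan(ctx, s, x, dt, A, B);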
@@ -14795,9 +14809,11 @@ static void ggml_compute_forward_flash_attn_back(
 
 static void ggml_compute_forward_ssm_scan_f32(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        const struct ggml_tensor * src2,
+        const struct ggml_tensor * src0, // s
+        const struct ggml_tensor * src1, // x
+        const struct ggml_tensor * src2, // dt
+        const struct ggml_tensor * src3, // A
+        const struct ggml_tensor * src4, // B
               struct ggml_tensor * dst) {
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14806,18 +14822,19 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t nc  = src1->ne[0];
-    const int64_t n_t = src1->ne[2]; // number of tokens in the batch
+    const int64_t nc  = src0->ne[0];
+    const int64_t n_t = src1->ne[1]; // number of tokens in the batch
     const int64_t nr0 = ggml_nrows(src0);
 
-    GGML_ASSERT(nc*n_t*nr0 == ggml_nelements(src1));
+    GGML_ASSERT(nc*n_t*nr0 == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
+    GGML_ASSERT(src3->nb[0] == sizeof(float));
+    GGML_ASSERT(src4->nb[0] == sizeof(float));
     // allow merging multiple rows in the same vec operation
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    GGML_ASSERT(src1->nb[1] == src1->ne[0]*sizeof(float));
-    GGML_ASSERT(src2->nb[1] == src2->ne[0]*sizeof(float));
+    GGML_ASSERT(src3->nb[1] == src3->ne[0]*sizeof(float));
 
     // rows per thread
     const int dr = (nr0 + nth - 1)/nth;
@@ -14829,22 +14846,44 @@ static void ggml_compute_forward_ssm_scan_f32(
 
     // first batch
     {
-        float * dest = (float *) ((char *)  dst->data + ir0*( dst->nb[1]));
-        float * s    = (float *) ((char *) src0->data + ir0*(src0->nb[1]));
-        float * dA   = (float *) ((char *) src1->data + ir0*(src1->nb[1]));
-        float * dB_x = (float *) ((char *) src2->data + ir0*(src2->nb[1]));
-        ggml_vec_mul_f32(nc*ir, dest, s, dA);
-        ggml_vec_add_f32(nc*ir, dest, dest, dB_x);
+        float * dest = (float *) ((char *)  dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
+        float * s    = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
+        float * x    = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
+        float * dt   = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
+        float * A    = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+        float * B    = (float *) ((char *) src4->data); // {d_state, n_tok}
+        // d_inner
+        for (int i1 = 0; i1 < ir; ++i1) {
+            float dt_soft_plus = log1pf(expf(dt[i1]));
+            float x_dt = x[i1] * dt_soft_plus;
+            // d_state
+            for (int i0 = 0; i0 < nc; ++i0) {
+                int i = i0 + i1*nc;
+                // ssm_state * dA + dB * x
+                dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+            }
+        }
     }
 
     // compute state for rest of tokens, previous state comes from dest
-    for (int i2 = 1; i2 < n_t; i2++) {
-        float * dest = (float *) ((char *)  dst->data + ir0*( dst->nb[1]) +  i2   *( dst->nb[2]));
-        float * s    = (float *) ((char *)  dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2]));
-        float * dA   = (float *) ((char *) src1->data + ir0*(src1->nb[1]) +  i2   *(src1->nb[2]));
-        float * dB_x = (float *) ((char *) src2->data + ir0*(src2->nb[1]) +  i2   *(src2->nb[2]));
-        ggml_vec_mul_f32(nc*ir, dest, s, dA);
-        ggml_vec_add_f32(nc*ir, dest, dest, dB_x);
+    for (int i2 = 1; i2 < n_t; ++i2) {
+        float * dest = (float *) ((char *)  dst->data + ir0*( dst->nb[1]) +  i2   *( dst->nb[2])); // {d_state, d_inner, n_tok}
+        float * s    = (float *) ((char *)  dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
+        float * x    = (float *) ((char *) src1->data + ir0*(src1->nb[0]) +  i2   *(src1->nb[1])); // {d_inner, n_tok}
+        float * dt   = (float *) ((char *) src2->data + ir0*(src2->nb[0]) +  i2   *(src2->nb[1])); // {d_inner, n_tok}
+        float * A    = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+        float * B    = (float *) ((char *) src4->data +  i2   *(src4->nb[1])); // {d_state, n_tok}
+        // d_inner
+        for (int i1 = 0; i1 < ir; ++i1) {
+            float dt_soft_plus = log1pf(expf(dt[i1]));
+            float x_dt = x[i1] * dt_soft_plus;
+            // d_state
+            for (int i0 = 0; i0 < nc; ++i0) {
+                int i = i0 + i1*nc;
+                // ssm_state * dA + dB * x
+                dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+            }
+        }
     }
 }
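The kernel splits the d_inner rows across threads (dr rows each) and chains tokens through dst: token 0 reads the previous state from src0, every later token reads the slice of dst written for the preceding token. A self-contained sketch of that loop structure with flat indexing (a hypothetical helper, same math as the kernel above):

    #include <math.h>

    // scan the ir rows of one thread over n_t tokens;
    // s0 is the initial state, out receives one {nc, ir} state per token
    static void ssm_scan_rows(int nc, int ir, int n_t,
                              const float * s0,  // {nc, ir}
                              const float * x,   // {ir, n_t}
                              const float * dt,  // {ir, n_t}
                              const float * A,   // {nc, ir}
                              const float * B,   // {nc, n_t}
                              float * out) {     // {nc, ir, n_t}
        for (int i2 = 0; i2 < n_t; ++i2) {
            const float * s    = i2 == 0 ? s0 : out + (i2-1)*nc*ir; // previous state
            float       * dest = out + i2*nc*ir;
            for (int i1 = 0; i1 < ir; ++i1) {
                const float dt_soft_plus = log1pf(expf(dt[i1 + i2*ir]));
                const float x_dt = x[i1 + i2*ir] * dt_soft_plus;
                for (int i0 = 0; i0 < nc; ++i0) {
                    const int i = i0 + i1*nc;
                    dest[i] = s[i]*expf(dt_soft_plus*A[i]) + B[i0 + i2*nc]*x_dt;
                }
            }
        }
    }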
@@ -14853,11 +14892,13 @@ static void ggml_compute_forward_ssm_scan(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         const struct ggml_tensor * src2,
+        const struct ggml_tensor * src3,
+        const struct ggml_tensor * src4,
               struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, dst);
+                ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, dst);
             } break;
         default:
             {
@@ -15927,7 +15968,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SSM_SCAN:
             {
-                ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+                ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor);
             } break;
         case GGML_OP_WIN_PART:
             {
ggml.h (6 changed lines)

@@ -1724,8 +1724,10 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_ssm_scan(
             struct ggml_context * ctx,
             struct ggml_tensor  * s,
-            struct ggml_tensor  * dA,
-            struct ggml_tensor  * dB_x);
+            struct ggml_tensor  * x,
+            struct ggml_tensor  * dt,
+            struct ggml_tensor  * A,
+            struct ggml_tensor  * B);
 
     // partition into non-overlapping windows with padding if needed
     // example:
llama.cpp (45 changed lines)

@@ -8005,49 +8005,12 @@ struct llm_build_context {
             // {dt_rank, d_inner} * {dt_rank, n_tok} => {d_inner, n_tok}
             dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
             dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
-            dt = ggml_soft_plus(ctx0, dt);
 
-            struct ggml_tensor * dA;
-            struct ggml_tensor * dB;
-            if (n_tok == 1) {
-                // => {d_state, d_inner}
-                dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, ggml_transpose(ctx0, dt)));
-
-                // {d_state} * {d_inner} => {d_state, d_inner}
-                dB = ggml_out_prod(ctx0, B, dt);
-            } else {
-                // {d_state, d_inner} * {d_inner, n_tok} => {d_state, d_inner, n_tok} * {1, d_inner, n_tok}
-                // => {d_state, d_inner, n_tok}
-                // Trying to do the equivalent of
-                // dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate)
-                struct ggml_tensor * A = model.layers[il].ssm_a;
-                dA = ggml_exp(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, A, ggml_new_tensor_3d(ctx0, A->type, d_state, d_inner, n_tok)),
-                        // {d_inner, n_tok} => {1, d_inner, n_tok}
-                        ggml_permute(ctx0, dt, 1, 2, 0, 3))
-                );
-
-                // {d_state, 1, n_tok} * {d_inner, 1, n_tok} => {d_state, d_inner, n_tok}
-                dB = ggml_out_prod(ctx0,
-                    // {d_state, n_tok} => {d_state, 1, n_tok}
-                    ggml_permute(ctx0, B, 0, 2, 1, 3),
-                    // {d_state, n_tok} => {d_state, 1, n_tok}
-                    ggml_permute(ctx0, dt, 0, 2, 1, 3));
-            }
-
-            // {d_state, d_inner, n_tok} * {1, d_inner, n_tok} => {d_state, d_inner, n_tok}
-            cur = ggml_mul(ctx0, dB, ggml_permute(ctx0, x, 1, 2, 0, 3));
-
-            // The selective scan seems inherently sequential...
-            // To avoid making (n_layer * n_tok) graph nodes, let's use a custom operator.
-            // When n_tok == 1, it's equivalent to the following:
-            // ssm_state = ggml_add(ctx0, ggml_mul(ctx0, ssm_state, dA), cur);
-            // When n_tok is bigger, it's the same thing, but iterated n_tok times,
-            // with the correct dA and cur for each token.
-            // The resulting states are layered on the ne[2] dimension.
+            // Custom operator to implement some of the optimizations
+            // described in the Annex D of the Mamba paper.
+            // TODO: maybe also optimize step 4 of the Speed section of Annex D (the mul_mat with C)
+
             // => {d_state, d_inner, n_tok}
-            ssm_state = ggml_ssm_scan(ctx0, ssm_state, dA, cur);
+            ssm_state = ggml_ssm_scan(ctx0, ssm_state, x, dt, model.layers[il].ssm_a, B);
 
             // only store last state
             ggml_build_forward_expand(gf,
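Since the operator returns one state per token stacked along ne[2], but only the state after the last token needs to be carried over to the next batch, a view of the last slice suffices. A sketch of that step (ggml_view_2d is existing ggml API; the exact expression used inside the truncated ggml_build_forward_expand call above is not shown in this hunk, so this is an assumption):

    // keep only the state after the last token:
    // {d_state, d_inner, n_tok} => {d_state, d_inner}
    struct ggml_tensor * last_state = ggml_view_2d(ctx0, ssm_state,
            d_state, d_inner,
            ssm_state->nb[1],
            (n_tok - 1)*ssm_state->nb[2]);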