mamba : reduce memory usage of ggml_ssm_scan
From 290.37 MiB to 140.68 MiB of CPU compute buffer size with Mamba 3B with a batch size of 512. The result tensor of ggml_ssm_scan was previously a big part of the CPU compute buffer size. To make it smaller, it does not contain the intermediate ssm states anymore. Both y and the last ssm state are combined in the result tensor, because it seems only a single tensor can be returned by an operator with the way the graph is built.
This commit is contained in:
parent
e73eaa7b4f
commit
de50c549c4
3 changed files with 70 additions and 50 deletions
72
ggml.c
72
ggml.c
|
@ -6087,14 +6087,15 @@ struct ggml_tensor * ggml_ssm_scan(
|
|||
struct ggml_tensor * x,
|
||||
struct ggml_tensor * dt,
|
||||
struct ggml_tensor * A,
|
||||
struct ggml_tensor * B) {
|
||||
struct ggml_tensor * B,
|
||||
struct ggml_tensor * C) {
|
||||
GGML_ASSERT(ggml_is_contiguous(s));
|
||||
GGML_ASSERT(ggml_is_contiguous(x));
|
||||
GGML_ASSERT(ggml_is_contiguous(dt));
|
||||
GGML_ASSERT(ggml_is_contiguous(A));
|
||||
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
|
||||
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
|
||||
GGML_ASSERT(ggml_are_same_shape(x, dt));
|
||||
GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D
|
||||
|
||||
{
|
||||
const int64_t d_state = s->ne[0];
|
||||
|
@ -6106,6 +6107,8 @@ struct ggml_tensor * ggml_ssm_scan(
|
|||
GGML_ASSERT(A->ne[1] == d_inner);
|
||||
GGML_ASSERT(B->ne[0] == d_state);
|
||||
GGML_ASSERT(B->ne[1] == n_tokens);
|
||||
GGML_ASSERT(C->ne[0] == d_state);
|
||||
GGML_ASSERT(C->ne[1] == n_tokens);
|
||||
}
|
||||
|
||||
bool is_node = false;
|
||||
|
@ -6115,7 +6118,8 @@ struct ggml_tensor * ggml_ssm_scan(
|
|||
is_node = true;
|
||||
}
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, s->ne[0], s->ne[1], x->ne[1]);
|
||||
// 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
|
||||
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
|
||||
|
||||
result->op = GGML_OP_SSM_SCAN;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
|
@ -6124,6 +6128,7 @@ struct ggml_tensor * ggml_ssm_scan(
|
|||
result->src[2] = dt;
|
||||
result->src[3] = A;
|
||||
result->src[4] = B;
|
||||
result->src[5] = C;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
@ -14650,6 +14655,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||
const struct ggml_tensor * src2, // dt
|
||||
const struct ggml_tensor * src3, // A
|
||||
const struct ggml_tensor * src4, // B
|
||||
const struct ggml_tensor * src5, // C
|
||||
struct ggml_tensor * dst) {
|
||||
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
||||
return;
|
||||
|
@ -14658,67 +14664,84 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int64_t nc = src0->ne[0];
|
||||
const int64_t nc = src0->ne[0]; // d_state
|
||||
const int64_t nr = src0->ne[1]; // d_inner
|
||||
const int64_t n_t = src1->ne[1]; // number of tokens in the batch
|
||||
const int64_t nr0 = ggml_nrows(src0);
|
||||
|
||||
GGML_ASSERT(nc*n_t*nr0 == ggml_nelements(dst));
|
||||
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src2->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src3->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src4->nb[0] == sizeof(float));
|
||||
// allow merging multiple rows in the same vec operation
|
||||
GGML_ASSERT(src5->nb[0] == sizeof(float));
|
||||
// required for the dot product between s and C
|
||||
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
||||
GGML_ASSERT(src3->nb[1] == src3->ne[0]*sizeof(float));
|
||||
// required to get correct offset for state destination
|
||||
GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr0 + nth - 1)/nth;
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr0);
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
const int ir = ir1 - ir0;
|
||||
|
||||
// first batch
|
||||
// first token in the batch
|
||||
{
|
||||
float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tokens}
|
||||
float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
|
||||
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
|
||||
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + src1->nb[2]); // {d_state, d_inner, n_kv}
|
||||
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner, n_kv}
|
||||
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
|
||||
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens}
|
||||
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
||||
float * B = (float *) ((char *) src4->data); // {d_state, n_tokens}
|
||||
float * C = (float *) ((char *) src5->data); // {d_state, n_tokens}
|
||||
// d_inner
|
||||
for (int i1 = 0; i1 < ir; ++i1) {
|
||||
float dt_soft_plus = log1pf(expf(dt[i1]));
|
||||
float x_dt = x[i1] * dt_soft_plus;
|
||||
float sumf = 0.0f;
|
||||
// d_state
|
||||
for (int i0 = 0; i0 < nc; ++i0) {
|
||||
int i = i0 + i1*nc;
|
||||
// ssm_state * dA + dB * x
|
||||
pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
||||
// state = prev_state * dA + dB * x
|
||||
float state = s0[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
||||
// y = rowwise_dotprod(state, C)
|
||||
sumf += state*C[i0];
|
||||
// FIXME: handle simultaneous sequences
|
||||
s[i] = state;
|
||||
}
|
||||
y[i1] = sumf;
|
||||
}
|
||||
}
|
||||
|
||||
// compute state for rest of tokens, previous state comes from dest
|
||||
// rest of the batch, state comes from previous one which was stored in destination
|
||||
for (int i2 = 1; i2 < n_t; ++i2) {
|
||||
float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tokens}
|
||||
float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tokens}
|
||||
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tokens}
|
||||
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tokens}
|
||||
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
||||
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + src1->nb[2]); // {d_state, d_inner, n_kv}
|
||||
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
||||
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
|
||||
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
||||
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
|
||||
float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
|
||||
// d_inner
|
||||
for (int i1 = 0; i1 < ir; ++i1) {
|
||||
float dt_soft_plus = log1pf(expf(dt[i1]));
|
||||
float x_dt = x[i1] * dt_soft_plus;
|
||||
float sumf = 0.0f;
|
||||
// d_state
|
||||
for (int i0 = 0; i0 < nc; ++i0) {
|
||||
int i = i0 + i1*nc;
|
||||
// ssm_state * dA + dB * x
|
||||
pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
||||
// state = prev_state * dA + dB * x
|
||||
float state = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
||||
// y = rowwise_dotprod(state, C)
|
||||
sumf += state*C[i0];
|
||||
// FIXME: handle simultaneous sequences
|
||||
s[i] = state;
|
||||
}
|
||||
y[i1] = sumf;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -14730,11 +14753,12 @@ static void ggml_compute_forward_ssm_scan(
|
|||
const struct ggml_tensor * src2,
|
||||
const struct ggml_tensor * src3,
|
||||
const struct ggml_tensor * src4,
|
||||
const struct ggml_tensor * src5,
|
||||
struct ggml_tensor * dst) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, dst);
|
||||
ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, src5, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
|
@ -15796,7 +15820,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||
} break;
|
||||
case GGML_OP_SSM_SCAN:
|
||||
{
|
||||
ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor);
|
||||
ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5], tensor);
|
||||
} break;
|
||||
case GGML_OP_WIN_PART:
|
||||
{
|
||||
|
|
3
ggml.h
3
ggml.h
|
@ -1708,7 +1708,8 @@ extern "C" {
|
|||
struct ggml_tensor * x,
|
||||
struct ggml_tensor * dt,
|
||||
struct ggml_tensor * A,
|
||||
struct ggml_tensor * B);
|
||||
struct ggml_tensor * B,
|
||||
struct ggml_tensor * C);
|
||||
|
||||
// partition into non-overlapping windows with padding if needed
|
||||
// example:
|
||||
|
|
31
llama.cpp
31
llama.cpp
|
@ -8029,9 +8029,9 @@ struct llm_build_context {
|
|||
ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], (d_conv-1)*(d_inner), kv_self.size);
|
||||
ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], (d_state)*(d_inner), kv_self.size);
|
||||
|
||||
// clear states of sequences which are starting at the beginning of this batch
|
||||
{
|
||||
ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
|
||||
// clear states of sequences which are starting at the beginning of this batch
|
||||
conv_states = ggml_mul(ctx0,
|
||||
ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
|
||||
state_mask);
|
||||
|
@ -8040,11 +8040,8 @@ struct llm_build_context {
|
|||
state_mask);
|
||||
}
|
||||
|
||||
// TODO: support more than one sequence per batch (these could then use ggml_reshape_3d)
|
||||
ggml_tensor * conv_state = ggml_view_2d(ctx0, conv_states, d_conv - 1, d_inner,
|
||||
(d_conv - 1)*ggml_element_size(conv_states), 0);
|
||||
ggml_tensor * ssm_state = ggml_view_2d(ctx0, ssm_states, d_state, d_inner,
|
||||
(d_state)*ggml_element_size(ssm_states), 0);
|
||||
struct ggml_tensor * conv_state = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
|
||||
struct ggml_tensor * ssm_state = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv);
|
||||
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
|
@ -8110,22 +8107,20 @@ struct llm_build_context {
|
|||
dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
|
||||
dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
|
||||
|
||||
// Custom operator to implement some of the optimizations
|
||||
// described in the Annex D of the Mamba paper.
|
||||
// TODO: maybe also optimize step 4 of the Speed section of Annex D (the mul_mat with C)
|
||||
// => {d_state, d_inner, n_tokens}
|
||||
ssm_state = ggml_ssm_scan(ctx0, ssm_state, x, dt, model.layers[il].ssm_a, B);
|
||||
// Custom operator to optimize the parallel associative scan
|
||||
// as described in the Annex D of the Mamba paper.
|
||||
// => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined,
|
||||
// because only a single tensor can be returned.
|
||||
struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_state, x, dt, model.layers[il].ssm_a, B, C);
|
||||
|
||||
// only store last state
|
||||
// store last states (the second part of y_ssm_states)
|
||||
ggml_build_forward_expand(gf,
|
||||
ggml_cpy(ctx0,
|
||||
ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tokens-1)*ssm_state->nb[2]),
|
||||
ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner, kv_self.head*d_state*d_inner*ggml_element_size(ssm_state))));
|
||||
ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
|
||||
ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_state))));
|
||||
|
||||
struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
|
||||
|
||||
// {d_state, d_inner, n_tokens} * {d_state, n_tokens} => {d_inner, 1, n_tokens}
|
||||
struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
|
||||
// => {d_inner, n_tokens}
|
||||
y = ggml_permute(ctx0, y, 0, 2, 1, 3);
|
||||
// {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
|
||||
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
|
||||
y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue