mamba : apply suggestions from code review

* mamba : remove unnecessary branch for row-wise ssm_state and C multiplication

It was previously done to avoid permuting when only one token is processed
at a time (like when generating text), but permuting is cheap,
and dynamically changing the compute graph is not future-proof (see the sketch below).

* ggml : in ggml_ssm_scan, use more appropriate asserts

* ggml : rename the destination pointer in ggml_compute_forward_ssm_scan_f32
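
A rough sketch of the unified path from the first point, for illustration only: the shapes in the comments and the temporary name Cp are assumptions, while ctx0, ssm_state and C are the names used in the diff further below. To my understanding, ggml_permute only creates a view with reordered strides, which is why the permute is cheap even for a single token.

// Sketch: one code path for any n_tok, including n_tok == 1.
// C: {d_state, n_tok} -> permuted view {d_state, 1, n_tok}
struct ggml_tensor * Cp = ggml_permute(ctx0, C, 0, 2, 1, 3);
// {d_state, d_inner, n_tok} * {d_state, 1, n_tok} => {d_inner, 1, n_tok}
struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, Cp);
// fold the middle dimension back out: => {d_inner, n_tok}
y = ggml_permute(ctx0, y, 0, 2, 1, 3);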
Francis Couture-Harpin 2024-02-05 10:13:55 -05:00
parent c52fb3c2de
commit 6ff34da092
2 changed files with 11 additions and 17 deletions

ggml.c (13 lines changed)

@@ -6093,8 +6093,8 @@ struct ggml_tensor * ggml_ssm_scan(
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
     GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
-    ggml_are_same_shape(x, dt);
-    GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1); // the ssm_state should be 2D
+    GGML_ASSERT(ggml_are_same_shape(x, dt));
+    GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D
 
     {
         const int64_t d_state = s->ne[0];
@@ -6111,6 +6111,7 @@ struct ggml_tensor * ggml_ssm_scan(
     bool is_node = false;
 
     if (s->grad || x->grad || dt->grad || A->grad || B->grad) {
+        GGML_ASSERT(false); // TODO: implement
         is_node = true;
     }
@@ -14681,7 +14682,7 @@ static void ggml_compute_forward_ssm_scan_f32(
     // first batch
     {
-        float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
         float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
         float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
         float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
@@ -14695,14 +14696,14 @@ static void ggml_compute_forward_ssm_scan_f32(
             for (int i0 = 0; i0 < nc; ++i0) {
                 int i = i0 + i1*nc;
                 // ssm_state * dA + dB * x
-                dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
             }
         }
     }
 
     // compute state for rest of tokens, previous state comes from dest
     for (int i2 = 1; i2 < n_t; ++i2) {
-        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
         float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
         float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok}
         float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok}
@@ -14716,7 +14717,7 @@ static void ggml_compute_forward_ssm_scan_f32(
             for (int i0 = 0; i0 < nc; ++i0) {
                 int i = i0 + i1*nc;
                 // ssm_state * dA + dB * x
-                dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
             }
         }
     }
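
For reference, the per-element update in the loops above (the // ssm_state * dA + dB * x lines) is the discretized state-space recurrence. Below is a self-contained scalar sketch of one step; the softplus applied to dt (giving dt_soft_plus and x_dt) is assumed from the surrounding kernel code, which is not shown in these hunks.

#include <math.h>

// One scan step for a single (state, channel) pair:
//   s_new = s * exp(softplus(dt) * A) + B * (x * softplus(dt))
// which is what pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt) computes.
static float ssm_scan_step(float s, float x, float dt, float A, float B) {
    const float dt_soft_plus = log1pf(expf(dt)); // softplus(dt)
    const float x_dt         = x * dt_soft_plus;
    return s * expf(dt_soft_plus * A) + B * x_dt;
}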

llama.cpp

@@ -8010,17 +8010,10 @@ struct llm_build_context {
                     ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tok-1)*ssm_state->nb[2]),
                     ggml_view_tensor(ctx0, kv_self.v_l[il])));
-            struct ggml_tensor * y;
-            if (n_tok == 1) {
-                // row-wise dot product ("dn,n->d")
-                // {d_state, d_inner} * {d_state, 1} => {d_inner, 1}
-                y = ggml_mul_mat(ctx0, ssm_state, C);
-            } else {
-                // {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok}
-                y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
-                // => {d_inner, n_tok}
-                y = ggml_permute(ctx0, y, 0, 2, 1, 3);
-            }
+            // {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok}
+            struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
+            // => {d_inner, n_tok}
+            y = ggml_permute(ctx0, y, 0, 2, 1, 3);
             // {d_inner, n_tok} * {d_inner} => {d_inner, n_tok}
             y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
             y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
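
The last two graph ops above apply a skip connection through ssm_d and a SiLU gate on z. A scalar sketch of what they compute per element, assuming model.layers[il].ssm_d plays the role of the D parameter in the Mamba formulation:

#include <math.h>

// ggml_silu(v) = v * sigmoid(v)
static float silu(float v) { return v / (1.0f + expf(-v)); }

// y_out = (y + x * D) * silu(z), matching the ggml_add / ggml_mul / ggml_silu ops above.
static float mamba_out(float y, float x, float D, float z) {
    return (y + x * D) * silu(z);
}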