mamba : apply suggestions from code review
* mamba : remove unecessary branch for row-wise ssm_state and C multiplication It was previously done to avoid permuting when only one token is processed at a time (like when generating text), but permuting is cheap, and dynamically changing the compute graph is not future-proof. * ggml : in ggml_ssm_scan, use more appropriate asserts * ggml : rename the destination pointer in ggml_compute_forward_ssm_scan_f32
This commit is contained in:
parent
c52fb3c2de
commit
6ff34da092
2 changed files with 11 additions and 17 deletions
13
ggml.c
13
ggml.c
|
@ -6093,8 +6093,8 @@ struct ggml_tensor * ggml_ssm_scan(
|
||||||
GGML_ASSERT(ggml_is_contiguous(dt));
|
GGML_ASSERT(ggml_is_contiguous(dt));
|
||||||
GGML_ASSERT(ggml_is_contiguous(A));
|
GGML_ASSERT(ggml_is_contiguous(A));
|
||||||
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
|
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
|
||||||
ggml_are_same_shape(x, dt);
|
GGML_ASSERT(ggml_are_same_shape(x, dt));
|
||||||
GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1); // the ssm_state should be 2D
|
GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D
|
||||||
|
|
||||||
{
|
{
|
||||||
const int64_t d_state = s->ne[0];
|
const int64_t d_state = s->ne[0];
|
||||||
|
@ -6111,6 +6111,7 @@ struct ggml_tensor * ggml_ssm_scan(
|
||||||
bool is_node = false;
|
bool is_node = false;
|
||||||
|
|
||||||
if (s->grad || x->grad || dt->grad || A->grad || B->grad) {
|
if (s->grad || x->grad || dt->grad || A->grad || B->grad) {
|
||||||
|
GGML_ASSERT(false); // TODO: implement
|
||||||
is_node = true;
|
is_node = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14681,7 +14682,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||||
|
|
||||||
// first batch
|
// first batch
|
||||||
{
|
{
|
||||||
float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
|
float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
|
||||||
float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
|
float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
|
||||||
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
|
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
|
||||||
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
|
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
|
||||||
|
@ -14695,14 +14696,14 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||||
for (int i0 = 0; i0 < nc; ++i0) {
|
for (int i0 = 0; i0 < nc; ++i0) {
|
||||||
int i = i0 + i1*nc;
|
int i = i0 + i1*nc;
|
||||||
// ssm_state * dA + dB * x
|
// ssm_state * dA + dB * x
|
||||||
dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// compute state for rest of tokens, previous state comes from dest
|
// compute state for rest of tokens, previous state comes from dest
|
||||||
for (int i2 = 1; i2 < n_t; ++i2) {
|
for (int i2 = 1; i2 < n_t; ++i2) {
|
||||||
float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
|
float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
|
||||||
float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
|
float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
|
||||||
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok}
|
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok}
|
||||||
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok}
|
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok}
|
||||||
|
@ -14716,7 +14717,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||||
for (int i0 = 0; i0 < nc; ++i0) {
|
for (int i0 = 0; i0 < nc; ++i0) {
|
||||||
int i = i0 + i1*nc;
|
int i = i0 + i1*nc;
|
||||||
// ssm_state * dA + dB * x
|
// ssm_state * dA + dB * x
|
||||||
dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
15
llama.cpp
15
llama.cpp
|
@ -8010,17 +8010,10 @@ struct llm_build_context {
|
||||||
ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tok-1)*ssm_state->nb[2]),
|
ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tok-1)*ssm_state->nb[2]),
|
||||||
ggml_view_tensor(ctx0, kv_self.v_l[il])));
|
ggml_view_tensor(ctx0, kv_self.v_l[il])));
|
||||||
|
|
||||||
struct ggml_tensor * y;
|
// {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok}
|
||||||
if (n_tok == 1) {
|
struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
|
||||||
// row-wise dot product ("dn,n->d")
|
// => {d_inner, n_tok}
|
||||||
// {d_state, d_inner} * {d_state, 1} => {d_inner, 1}
|
y = ggml_permute(ctx0, y, 0, 2, 1, 3);
|
||||||
y = ggml_mul_mat(ctx0, ssm_state, C);
|
|
||||||
} else {
|
|
||||||
// {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok}
|
|
||||||
y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
|
|
||||||
// => {d_inner, n_tok}
|
|
||||||
y = ggml_permute(ctx0, y, 0, 2, 1, 3);
|
|
||||||
}
|
|
||||||
// {d_inner, n_tok} * {d_inner} => {d_inner, n_tok}
|
// {d_inner, n_tok} * {d_inner} => {d_inner, n_tok}
|
||||||
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
|
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
|
||||||
y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
|
y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue