diff --git a/ggml.c b/ggml.c index a7f55016e..2ab47216e 100644 --- a/ggml.c +++ b/ggml.c @@ -6093,8 +6093,8 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); - ggml_are_same_shape(x, dt); - GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1); // the ssm_state should be 2D + GGML_ASSERT(ggml_are_same_shape(x, dt)); + GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D { const int64_t d_state = s->ne[0]; @@ -6111,6 +6111,7 @@ struct ggml_tensor * ggml_ssm_scan( bool is_node = false; if (s->grad || x->grad || dt->grad || A->grad || B->grad) { + GGML_ASSERT(false); // TODO: implement is_node = true; } @@ -14681,7 +14682,7 @@ static void ggml_compute_forward_ssm_scan_f32( // first batch { - float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok} + float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok} float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner} float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok} float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok} @@ -14695,14 +14696,14 @@ static void ggml_compute_forward_ssm_scan_f32( for (int i0 = 0; i0 < nc; ++i0) { int i = i0 + i1*nc; // ssm_state * dA + dB * x - dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); } } } // compute state for rest of tokens, previous state comes from dest for (int i2 = 1; i2 < n_t; ++i2) { - float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok} + float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok} float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok} float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok} float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok} @@ -14716,7 +14717,7 @@ static void ggml_compute_forward_ssm_scan_f32( for (int i0 = 0; i0 < nc; ++i0) { int i = i0 + i1*nc; // ssm_state * dA + dB * x - dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); } } } diff --git a/llama.cpp b/llama.cpp index d320d727f..fc1dd024e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8010,17 +8010,10 @@ struct llm_build_context { ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tok-1)*ssm_state->nb[2]), ggml_view_tensor(ctx0, kv_self.v_l[il]))); - struct ggml_tensor * y; - if (n_tok == 1) { - // row-wise dot product ("dn,n->d") - // {d_state, d_inner} * {d_state, 1} => {d_inner, 1} - y = ggml_mul_mat(ctx0, ssm_state, C); - } else { - // {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok} - y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3)); - // => {d_inner, n_tok} - y = ggml_permute(ctx0, y, 0, 2, 1, 3); - } + // {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok} + struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3)); + // => {d_inner, n_tok} + y = ggml_permute(ctx0, y, 0, 2, 1, 3); // {d_inner, n_tok} * {d_inner} => {d_inner, n_tok} y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));