llama : avoid redundant state copy for Mamba 1 and 2
parent 0e601cafe9
commit 273e7a495a
4 changed files with 142 additions and 119 deletions
@@ -1833,7 +1833,8 @@ extern "C" {
             struct ggml_tensor * A,
             struct ggml_tensor * B,
             struct ggml_tensor * C,
-            struct ggml_tensor * D);
+            struct ggml_tensor * D,
+            struct ggml_tensor * ids);

     // partition into non-overlapping windows with padding if needed
     // example:
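Note (not part of the commit): a hedged sketch of what a caller of the extended API might look like. The extra `ids` tensor, asserted further down to be I32, names the saved state slot each sequence in the ubatch resumes from, so the scan can gather states itself instead of requiring a copy beforehand. The helper name and shape comments below are illustrative assumptions; only the ggml calls themselves are real.

```c
// Illustrative only: building the new `ids` argument for ggml_ssm_scan.
#include "ggml.h"

static struct ggml_tensor * build_ssm_scan(
        struct ggml_context * ctx,
        struct ggml_tensor  * s,   // saved states, one slot per tracked sequence
        struct ggml_tensor  * x,   // {head_dim, n_head, n_seq_tokens, n_seqs}
        struct ggml_tensor  * dt,
        struct ggml_tensor  * A,
        struct ggml_tensor  * B,
        struct ggml_tensor  * C,
        struct ggml_tensor  * D) {
    const int64_t n_seqs = x->ne[3];

    // one I32 entry per sequence in the ubatch: the state slot it starts from
    struct ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
    // ids->data would be filled with the per-sequence slot indices before compute

    // the scan gathers the source states via ids, so no up-front copy is needed
    return ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids);
}
```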
@@ -7598,7 +7598,8 @@ struct ggml_tensor * ggml_ssm_scan(
         struct ggml_tensor * A,
         struct ggml_tensor * B,
         struct ggml_tensor * C,
-        struct ggml_tensor * D) {
+        struct ggml_tensor * D,
+        struct ggml_tensor * ids) {
     GGML_ASSERT(ggml_is_contiguous(s));
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
@@ -7609,6 +7610,7 @@ struct ggml_tensor * ggml_ssm_scan(
     GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
     GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
     GGML_ASSERT(ggml_are_same_shape(B, C));
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);

     {
         const int64_t d_state = s->ne[0];
@@ -7623,21 +7625,19 @@ struct ggml_tensor * ggml_ssm_scan(
         GGML_ASSERT(ggml_is_3d(dt));
         GGML_ASSERT(s->ne[1] == head_dim);
         GGML_ASSERT(s->ne[2] == n_head);
-        GGML_ASSERT(s->ne[3] == n_seqs);
         GGML_ASSERT(B->ne[0] == d_state);
         GGML_ASSERT(B->ne[2] == n_seq_tokens);
         GGML_ASSERT(B->ne[3] == n_seqs);
         GGML_ASSERT(D->ne[0] == n_head);
         GGML_ASSERT(ggml_is_vector(D));
+        GGML_ASSERT(ids->ne[0] == n_seqs);
+        GGML_ASSERT(ggml_is_vector(ids));
+        GGML_ASSERT(A->ne[1] == n_head);
+        GGML_ASSERT(ggml_is_matrix(A));

-        if (ggml_is_vector(A)) {
-            // Mamba-2
-            GGML_ASSERT(A->ne[0] == n_head);
-        } else {
-            // Mamba-1
+        if (A->ne[0] != 1) {
+            // Mamba-1 has more granular decay factors
             GGML_ASSERT(A->ne[0] == d_state);
-            GGML_ASSERT(A->ne[1] == n_head);
-            GGML_ASSERT(ggml_is_matrix(A));
         }
     }
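Note (not part of the commit): the reworked check encodes the A-shape convention directly instead of relying on `ggml_is_vector`: Mamba-2 passes one decay factor per head (A is {1, n_head}), while Mamba-1 keeps a full {d_state, n_head} matrix. A small plain-C illustration of what that distinction means for the state update; the `decay` helper and its names are mine, not from the commit.

```c
// Sketch: dA = exp(dt * A), with A either one value per head (Mamba-2)
// or one value per (state, head) pair (Mamba-1).
#include <math.h>
#include <stdint.h>

static float decay(const float * A, int64_t A_ne0, int64_t h, int64_t i0, float dt) {
    if (A_ne0 == 1) {
        return expf(dt * A[h]);        // Mamba-2: scalar decay shared by the whole head
    }
    return expf(dt * A[i0 + h*A_ne0]); // Mamba-1: per-state decay, A_ne0 == d_state
}
```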
@@ -7649,7 +7649,7 @@ struct ggml_tensor * ggml_ssm_scan(
     }

     // concatenated y + ssm_states
-    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);

     result->op = GGML_OP_SSM_SCAN;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
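Note (not part of the commit): this allocation change is what shrinks the op's output. `dst` still concatenates y with the updated states, but now only for the `ids->ne[0]` sequences present in the ubatch rather than for every slot in `s`. A rough worked comparison; all dimensions below are made up for illustration.

```c
// Hypothetical sizes showing the effect of the new allocation formula.
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t d_state  = 16, head_dim = 64, n_head = 24;
    const int64_t n_slots  = 8; // s->ne[3]: state slots kept for all tracked sequences
    const int64_t n_seqs   = 2; // ids->ne[0]: sequences actually in this ubatch
    const int64_t n_tokens = 5; // tokens per sequence

    const int64_t y_elems    = head_dim*n_head*n_tokens*n_seqs; // ggml_nelements(x)
    const int64_t states_old = d_state*head_dim*n_head*n_slots; // ggml_nelements(s)
    const int64_t states_new = d_state*head_dim*n_head*n_seqs;  // s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]

    printf("y: %lld, states before: %lld, states after: %lld\n",
           (long long) y_elems, (long long) states_old, (long long) states_new);
    return 0;
}
```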
@@ -7660,6 +7660,7 @@ struct ggml_tensor * ggml_ssm_scan(
     result->src[4] = B;
     result->src[5] = C;
     result->src[6] = D;
+    result->src[7] = ids;

     return result;
 }
@@ -16635,13 +16636,14 @@ static void ggml_compute_forward_ssm_conv(
 static void ggml_compute_forward_ssm_scan_f32(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs}
+    const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
     const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
     const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
-    const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {n_head}
+    const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
     const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
     const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
     const struct ggml_tensor * src6 = dst->src[6]; // D {n_head}
+    const struct ggml_tensor * src7 = dst->src[7]; // ids {n_seqs}

     const int ith = params->ith;
     const int nth = params->nth;
@@ -16651,11 +16653,12 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int64_t nh = src1->ne[1]; // n_head
     const int64_t ng = src4->ne[1];
     const int64_t nt = src1->ne[2]; // number of tokens per sequence
-    const int64_t ns = src0->ne[3]; // number of sequences in the batch
+    const int64_t ns = src1->ne[3]; // number of sequences in the batch

-    const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1);
+    // can't use ggml_nbytes because src1 is not necessarily contiguous
+    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);

-    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
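Note (not part of the commit): `s_off` is the byte offset inside `dst` where the state part of the concatenated output begins, so it has to be the packed size of the y part. The new comment spells out why an `ggml_nbytes`-style size would be wrong here: for a non-contiguous view, a stride-based byte count exceeds `nelements * element_size`. A standalone illustration with invented layout numbers:

```c
// Invented 2-D example: a 4x3 float view whose rows come out of a wider buffer.
// The packed size (what s_off needs) differs from a stride-based size.
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t ne[2] = { 4, 3 };                           // elements per dim
    const int64_t nb[2] = { sizeof(float), 8*sizeof(float) }; // row stride spans 8 floats

    const int64_t packed  = ne[0]*ne[1]*(int64_t)sizeof(float);                  // 48 bytes
    const int64_t strided = sizeof(float) + (ne[0]-1)*nb[0] + (ne[1]-1)*nb[1];   // 80 bytes

    printf("packed: %lld bytes, stride-based: %lld bytes\n",
           (long long) packed, (long long) strided);
    return 0;
}
```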
@@ -16663,6 +16666,7 @@ static void ggml_compute_forward_ssm_scan_f32(
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
     GGML_ASSERT(src6->nb[0] == sizeof(float));
+    GGML_ASSERT(src7->nb[0] == sizeof(int32_t));
     // allows optimizing the modulo since n_group should be a power of 2
     GGML_ASSERT((ng & -ng) == ng);

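Note (not part of the commit): the new `src7` assert only checks the element size of `ids`; the existing power-of-two assert on `ng` is kept because `i1 % ng` in the inner loops can then be lowered to a bit-mask. A trivial standalone demonstration of that check:

```c
// (ng & -ng) == ng holds exactly when ng has at most one bit set, i.e. n_group
// is a power of two, so `i % ng` can become `i & (ng - 1)`.
#include <stdio.h>

int main(void) {
    for (int ng = 1; ng <= 8; ++ng) {
        const int pow2  = (ng & -ng) == ng;
        const int match = (13 % ng) == (13 & (ng - 1)); // only equal for powers of two
        printf("ng=%d  power_of_two=%d  modulo==mask=%d\n", ng, pow2, match);
    }
    return 0;
}
```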
@@ -16673,22 +16677,22 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ih0 = dh*ith;
     const int ih1 = MIN(ih0 + dh, nh);

+    const int32_t * ids = (const int32_t *) src7->data;
+
     for (int i3 = 0; i3 < ns; ++i3) {
+        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
+        float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
         for (int i2 = 0; i2 < nt; ++i2) {
-            const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns}
             const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
             const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
-            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh}
+            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
             const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
             const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
             const float * D  = (const float *) ((const char *) src6->data); // {nh}
             float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
-            float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}

-            // use the output as the source when it's not the first token-wise iteration
-            if (i2 > 0) { s0 = s; }
-
-            if (ggml_is_vector(src3)) {
+            if (src3->ne[0] == 1) {
                 // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop

                 // n_head
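Note (not part of the commit): this loop restructuring is the "avoid redundant state copy" part. `s0` now starts out pointing at the saved state slot selected by `ids[i3]`, `s` points at this sequence's slot in the output, and after the first token the recurrence simply reads back what it just wrote. A condensed plain-C sketch of that pointer handoff (single thread, Mamba-2 path only, `n_group` folded to 1, the D*x term and exact kernel indexing simplified; all names here are mine):

```c
// Simplified sketch of the new per-sequence scan structure.
// Flattened layouts assumed: state {nc, nr, nh}, x/y {nr, nh, nt}, dt {nh, nt}, B/C {nc, nt}.
#include <math.h>
#include <stdint.h>

static void ssm_scan_seq_sketch(
        const float   * states,        // all saved state slots, each of nc*nr*nh floats
        const int32_t * ids,           // state slot per sequence in the ubatch
        int i3,                        // sequence index within the ubatch
        const float * x, const float * dt, const float * A,
        const float * B, const float * C,
        float * y, float * state_out,  // this sequence's y and state slot in dst
        int nt, int nh, int nr, int nc) {
    const float * s0 = states + (int64_t) ids[i3]*nc*nr*nh; // gather via ids, no copy
          float * s  = state_out;

    for (int i2 = 0; i2 < nt; ++i2) {
        for (int h = 0; h < nh; ++h) {
            const float dt_h = dt[h + i2*nh];
            const float dA   = expf(dt_h * A[h]); // Mamba-2: scalar decay per head
            for (int i1 = 0; i1 < nr; ++i1) {
                const float x_dt = x[i1 + h*nr + i2*nh*nr] * dt_h;
                float sumf = 0.0f;
                for (int i0 = 0; i0 < nc; ++i0) {
                    const int64_t i = i0 + (i1 + (int64_t) h*nr)*nc;
                    const float state = s0[i]*dA + B[i0 + i2*nc]*x_dt;
                    s[i]  = state;
                    sumf += state*C[i0 + i2*nc];
                }
                y[i1 + h*nr + i2*nh*nr] = sumf; // D*x term left out for brevity
            }
        }
        // after the first token, the freshly written state is the source
        s0 = s;
    }
}
```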
@@ -16778,6 +16782,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                     }
                 }
             }
+            // use the output as the source when it's not the first token-wise iteration
+            s0 = s;
         }
     }
 }