llama : avoid redundant state copy for Mamba 1 and 2

This commit is contained in:
Francis Couture-Harpin 2024-09-30 15:52:42 -04:00
parent 0e601cafe9
commit 273e7a495a
4 changed files with 142 additions and 119 deletions

View file

@ -1833,7 +1833,8 @@ extern "C" {
struct ggml_tensor * A,
struct ggml_tensor * B,
struct ggml_tensor * C,
struct ggml_tensor * D);
struct ggml_tensor * D,
struct ggml_tensor * ids);
// partition into non-overlapping windows with padding if needed
// example:

View file

@ -7598,7 +7598,8 @@ struct ggml_tensor * ggml_ssm_scan(
struct ggml_tensor * A,
struct ggml_tensor * B,
struct ggml_tensor * C,
struct ggml_tensor * D) {
struct ggml_tensor * D,
struct ggml_tensor * ids) {
GGML_ASSERT(ggml_is_contiguous(s));
GGML_ASSERT(ggml_is_contiguous(dt));
GGML_ASSERT(ggml_is_contiguous(A));
@ -7609,6 +7610,7 @@ struct ggml_tensor * ggml_ssm_scan(
GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
GGML_ASSERT(ggml_are_same_shape(B, C));
GGML_ASSERT(ids->type == GGML_TYPE_I32);
{
const int64_t d_state = s->ne[0];
@ -7623,21 +7625,19 @@ struct ggml_tensor * ggml_ssm_scan(
GGML_ASSERT(ggml_is_3d(dt));
GGML_ASSERT(s->ne[1] == head_dim);
GGML_ASSERT(s->ne[2] == n_head);
GGML_ASSERT(s->ne[3] == n_seqs);
GGML_ASSERT(B->ne[0] == d_state);
GGML_ASSERT(B->ne[2] == n_seq_tokens);
GGML_ASSERT(B->ne[3] == n_seqs);
GGML_ASSERT(D->ne[0] == n_head);
GGML_ASSERT(ggml_is_vector(D));
GGML_ASSERT(ids->ne[0] == n_seqs);
GGML_ASSERT(ggml_is_vector(ids));
GGML_ASSERT(A->ne[1] == n_head);
GGML_ASSERT(ggml_is_matrix(A));
if (ggml_is_vector(A)) {
// Mamba-2
GGML_ASSERT(A->ne[0] == n_head);
} else {
// Mamba-1
if (A->ne[0] != 1) {
// Mamba-1 has more granular decay factors
GGML_ASSERT(A->ne[0] == d_state);
GGML_ASSERT(A->ne[1] == n_head);
GGML_ASSERT(ggml_is_matrix(A));
}
}
@ -7649,7 +7649,7 @@ struct ggml_tensor * ggml_ssm_scan(
}
// concatenated y + ssm_states
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);
result->op = GGML_OP_SSM_SCAN;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -7660,6 +7660,7 @@ struct ggml_tensor * ggml_ssm_scan(
result->src[4] = B;
result->src[5] = C;
result->src[6] = D;
result->src[7] = ids;
return result;
}
@ -16635,13 +16636,14 @@ static void ggml_compute_forward_ssm_conv(
static void ggml_compute_forward_ssm_scan_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs}
const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {n_head}
const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
const struct ggml_tensor * src6 = dst->src[6]; // D {n_head}
const struct ggml_tensor * src7 = dst->src[7]; // ids {n_seqs}
const int ith = params->ith;
const int nth = params->nth;
@ -16651,11 +16653,12 @@ static void ggml_compute_forward_ssm_scan_f32(
const int64_t nh = src1->ne[1]; // n_head
const int64_t ng = src4->ne[1];
const int64_t nt = src1->ne[2]; // number of tokens per sequence
const int64_t ns = src0->ne[3]; // number of sequences in the batch
const int64_t ns = src1->ne[3]; // number of sequences in the batch
const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1);
// can't use ggml_nbytes because src1 is not necessarily contiguous
const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
@ -16663,6 +16666,7 @@ static void ggml_compute_forward_ssm_scan_f32(
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(float));
GGML_ASSERT(src7->nb[0] == sizeof(int32_t));
// allows optimizing the modulo since n_group should be a power of 2
GGML_ASSERT((ng & -ng) == ng);
@ -16673,22 +16677,22 @@ static void ggml_compute_forward_ssm_scan_f32(
const int ih0 = dh*ith;
const int ih1 = MIN(ih0 + dh, nh);
const int32_t * ids = (const int32_t *) src7->data;
for (int i3 = 0; i3 < ns; ++i3) {
const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
for (int i2 = 0; i2 < nt; ++i2) {
const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns}
const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh}
const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
const float * D = (const float *) ((const char *) src6->data); // {nh}
float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
// use the output as the source when it's not the first token-wise iteration
if (i2 > 0) { s0 = s; }
if (ggml_is_vector(src3)) {
if (src3->ne[0] == 1) {
// Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
// n_head
@ -16778,6 +16782,8 @@ static void ggml_compute_forward_ssm_scan_f32(
}
}
}
// use the output as the source when it's not the first token-wise iteration
s0 = s;
}
}
}