llama : initial Mamba-2 support

This commit is contained in:
Francis Couture-Harpin 2024-08-01 10:43:42 -04:00
parent a1631e53f6
commit 1f0fea70fb
7 changed files with 490 additions and 82 deletions

View file

@ -1787,7 +1787,8 @@ extern "C" {
struct ggml_tensor * dt,
struct ggml_tensor * A,
struct ggml_tensor * B,
struct ggml_tensor * C);
struct ggml_tensor * C,
struct ggml_tensor * D);
// partition into non-overlapping windows with padding if needed
// example:

View file

@ -7270,32 +7270,48 @@ struct ggml_tensor * ggml_ssm_scan(
struct ggml_tensor * dt,
struct ggml_tensor * A,
struct ggml_tensor * B,
struct ggml_tensor * C) {
struct ggml_tensor * C,
struct ggml_tensor * D) {
GGML_ASSERT(ggml_is_contiguous(s));
GGML_ASSERT(ggml_is_contiguous(x));
GGML_ASSERT(ggml_is_contiguous(dt));
GGML_ASSERT(ggml_is_contiguous(A));
GGML_ASSERT(ggml_is_matrix(A));
GGML_ASSERT(ggml_is_3d(B));
GGML_ASSERT(ggml_is_3d(s));
GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
GGML_ASSERT(ggml_are_same_shape(x, dt));
GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
GGML_ASSERT(ggml_are_same_shape(B, C));
{
const int64_t d_state = s->ne[0];
const int64_t d_inner = s->ne[1];
const int64_t n_seq_tokens = x->ne[1];
const int64_t n_seqs = x->ne[2];
const int64_t head_dim = x->ne[0];
const int64_t n_head = x->ne[1];
const int64_t n_seq_tokens = x->ne[2];
const int64_t n_seqs = x->ne[3];
GGML_ASSERT(s->ne[2] == n_seqs);
GGML_ASSERT(x->ne[0] == d_inner);
GGML_ASSERT(A->ne[0] == d_state);
GGML_ASSERT(A->ne[1] == d_inner);
GGML_ASSERT(dt->ne[0] == n_head);
GGML_ASSERT(dt->ne[1] == n_seq_tokens);
GGML_ASSERT(dt->ne[2] == n_seqs);
GGML_ASSERT(ggml_is_3d(dt));
GGML_ASSERT(s->ne[1] == head_dim);
GGML_ASSERT(s->ne[2] == n_head);
GGML_ASSERT(s->ne[3] == n_seqs);
GGML_ASSERT(B->ne[0] == d_state);
GGML_ASSERT(B->ne[1] == n_seq_tokens);
GGML_ASSERT(B->ne[2] == n_seqs);
GGML_ASSERT(B->ne[2] == n_seq_tokens);
GGML_ASSERT(B->ne[3] == n_seqs);
GGML_ASSERT(D->ne[0] == n_head);
GGML_ASSERT(ggml_is_vector(D));
if (ggml_is_vector(A)) {
// Mamba-2
GGML_ASSERT(A->ne[0] == n_head);
} else {
// Mamba-1
GGML_ASSERT(A->ne[0] == d_state);
GGML_ASSERT(A->ne[1] == n_head);
GGML_ASSERT(ggml_is_matrix(A));
}
}
bool is_node = false;
@ -7316,6 +7332,7 @@ struct ggml_tensor * ggml_ssm_scan(
result->src[3] = A;
result->src[4] = B;
result->src[5] = C;
result->src[6] = D;
return result;
}
@ -15840,20 +15857,25 @@ static void ggml_compute_forward_ssm_conv(
static void ggml_compute_forward_ssm_scan_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // s
const struct ggml_tensor * src1 = dst->src[1]; // x
const struct ggml_tensor * src2 = dst->src[2]; // dt
const struct ggml_tensor * src3 = dst->src[3]; // A
const struct ggml_tensor * src4 = dst->src[4]; // B
const struct ggml_tensor * src5 = dst->src[5]; // C
const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs}
const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {n_head}
const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
const struct ggml_tensor * src6 = dst->src[6]; // D {n_head}
const int ith = params->ith;
const int nth = params->nth;
const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // dim
const int64_t nh = src1->ne[1]; // n_head
const int64_t ng = src4->ne[1];
const int64_t nt = src1->ne[2]; // number of tokens per sequence
const int64_t ns = src0->ne[3]; // number of sequences in the batch
const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1);
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
@ -15862,51 +15884,86 @@ static void ggml_compute_forward_ssm_scan_f32(
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
// required for the dot product between s and C
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
// required for per-sequence offsets for states
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
// required to get correct offset for state destination (i.e. src1->nb[3])
GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(float));
// allows optimizing the modulo since n_group should be a power of 2
GGML_ASSERT((ng & -ng) == ng);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// heads per thread
const int dh = (nh + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0;
// head range for this thread
const int ih0 = dh*ith;
const int ih1 = MIN(ih0 + dh, nh);
for (int i3 = 0; i3 < n_s; ++i3) {
for (int i2 = 0; i2 < n_t; ++i2) {
const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
for (int i3 = 0; i3 < ns; ++i3) {
for (int i2 = 0; i2 < nt; ++i2) {
const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns}
const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh}
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
const float * D = (const float *) ((const char *) src6->data); // {nh}
float * y = (float *) ((char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
float * s = (float *) ((char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
// use the output as the source for the next token-wise iterations
// use the output as the source when it's not the first token-wise iteration
if (i2 > 0) { s0 = s; }
// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
float x_dt = x[i1] * dt_soft_plus;
float sumf = 0.0f;
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc;
// state = prev_state * dA + dB * x
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[i0];
s[i] = state;
if (ggml_is_vector(src3)) {
// Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
const float dA = expf(dt_soft_plus * A[h]);
// TODO: SIMD implementation
// dim
for (int i1 = 0; i1 < nr; ++i1) {
const int i = i1 + h*nr;
const float x_dt = x[i] * dt_soft_plus;
float sumf = 0.0f;
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
const int ii = i0 + i*nc;
const int ig = i0 + (h & (ng - 1))*nc;
// state = prev_state * dA + dB * x
const float state = (s0[ii] * dA) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[ig];
s[ii] = state;
}
y[i] = sumf + x[i] * D[h];
}
}
} else {
// Mamba-1 has an element-wise decay factor for the states
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
// dim
for (int i1 = 0; i1 < nr; ++i1) {
const int i = i1 + h*nr;
const float x_dt = x[i] * dt_soft_plus;
float sumf = 0.0f;
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
const int ii = i0 + i*nc;
const int ig = i0 + (h & (ng - 1))*nc;
// state = prev_state * dA + dB * x
const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[ig];
s[ii] = state;
}
y[i] = sumf + x[i] * D[h];
}
}
y[i1] = sumf;
}
}
}