ggml : avoid multiply by D in GGML_OP_SSM_SCAN
This makes the weight buft detection in src/llama.cpp simpler. * convert : transpose Mamba-2 A, D and reshape SSM_NORM This breaks existing conversions of Mamba-2 models to avoid some reshapes. Not sure if it's a good idea, but it makes the graph slightly cleaner. * llama : more appropriate SSM_SCAN and SSM_CONV buft support checks
This commit is contained in:
parent
7d16e1bc8c
commit
3bc7103d2e
7 changed files with 98 additions and 95 deletions
|
@ -1828,7 +1828,6 @@ extern "C" {
|
|||
struct ggml_tensor * A,
|
||||
struct ggml_tensor * B,
|
||||
struct ggml_tensor * C,
|
||||
struct ggml_tensor * D,
|
||||
struct ggml_tensor * ids);
|
||||
|
||||
// partition into non-overlapping windows with padding if needed
|
||||
|
|
|
@ -1649,25 +1649,21 @@ static void ggml_metal_encode_node(
|
|||
struct ggml_tensor * src4 = node->src[4];
|
||||
struct ggml_tensor * src5 = node->src[5];
|
||||
struct ggml_tensor * src6 = node->src[6];
|
||||
struct ggml_tensor * src7 = node->src[7];
|
||||
|
||||
GGML_ASSERT(src3);
|
||||
GGML_ASSERT(src4);
|
||||
GGML_ASSERT(src5);
|
||||
GGML_ASSERT(src6);
|
||||
GGML_ASSERT(src7);
|
||||
|
||||
size_t offs_src3 = 0;
|
||||
size_t offs_src4 = 0;
|
||||
size_t offs_src5 = 0;
|
||||
size_t offs_src6 = 0;
|
||||
size_t offs_src7 = 0;
|
||||
|
||||
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
||||
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
|
||||
id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
|
||||
id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
|
||||
id<MTLBuffer> id_src7 = src7 ? ggml_metal_get_buffer(src7, &offs_src7) : nil;
|
||||
|
||||
const int64_t ne30 = src3->ne[0];
|
||||
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
|
||||
|
@ -1699,10 +1695,6 @@ static void ggml_metal_encode_node(
|
|||
|
||||
const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
|
||||
|
||||
const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70);
|
||||
|
||||
const uint64_t nb70 = src7->nb[0]; GGML_UNUSED(nb70);
|
||||
|
||||
const int64_t d_state = ne00;
|
||||
const int64_t d_inner = ne01;
|
||||
const int64_t n_head = ne02;
|
||||
|
@ -1727,31 +1719,30 @@ static void ggml_metal_encode_node(
|
|||
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
||||
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
||||
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
|
||||
[encoder setBuffer:id_src7 offset:offs_src7 atIndex:7];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:8];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
|
||||
|
||||
[encoder setBytes:&d_state length:sizeof(d_state) atIndex:9];
|
||||
[encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:10];
|
||||
[encoder setBytes:&n_head length:sizeof(n_head) atIndex:11];
|
||||
[encoder setBytes:&n_group length:sizeof(n_group) atIndex:12];
|
||||
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13];
|
||||
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14];
|
||||
[encoder setBytes:&d_state length:sizeof(d_state) atIndex:8];
|
||||
[encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:9];
|
||||
[encoder setBytes:&n_head length:sizeof(n_head) atIndex:10];
|
||||
[encoder setBytes:&n_group length:sizeof(n_group) atIndex:11];
|
||||
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:12];
|
||||
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:13];
|
||||
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:15];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:16];
|
||||
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:17];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
|
||||
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:20];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:21];
|
||||
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:22];
|
||||
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:23];
|
||||
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
|
||||
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
|
||||
[encoder setBytes:&nb43 length:sizeof(nb43) atIndex:26];
|
||||
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
|
||||
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
|
||||
[encoder setBytes:&nb53 length:sizeof(nb53) atIndex:29];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15];
|
||||
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18];
|
||||
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20];
|
||||
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21];
|
||||
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
|
||||
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23];
|
||||
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24];
|
||||
[encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25];
|
||||
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26];
|
||||
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27];
|
||||
[encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28];
|
||||
// NOTE: max index is 31
|
||||
|
||||
if (ne30 == 1) {
|
||||
|
|
|
@ -805,7 +805,6 @@ kernel void kernel_ssm_scan_f32(
|
|||
device const void * src4,
|
||||
device const void * src5,
|
||||
device const void * src6,
|
||||
device const void * src7,
|
||||
device float * dst,
|
||||
constant int64_t & d_state,
|
||||
constant int64_t & d_inner,
|
||||
|
@ -838,7 +837,6 @@ kernel void kernel_ssm_scan_f32(
|
|||
const uint64_t nb00 = sizeof(float);
|
||||
const uint64_t nb10 = sizeof(float);
|
||||
const uint64_t nb20 = sizeof(float);
|
||||
const uint64_t nb60 = sizeof(float);
|
||||
|
||||
const int64_t nc = d_state;
|
||||
const int64_t nr = d_inner;
|
||||
|
@ -848,7 +846,7 @@ kernel void kernel_ssm_scan_f32(
|
|||
|
||||
const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float);
|
||||
|
||||
device const int32_t * ids = (device const int32_t *) src7;
|
||||
device const int32_t * ids = (device const int32_t *) src6;
|
||||
|
||||
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03);
|
||||
device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off);
|
||||
|
@ -859,7 +857,6 @@ kernel void kernel_ssm_scan_f32(
|
|||
device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {d_state, nh}
|
||||
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
|
||||
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
|
||||
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
|
||||
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
|
||||
|
||||
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
||||
|
@ -873,7 +870,7 @@ kernel void kernel_ssm_scan_f32(
|
|||
s[i] = state;
|
||||
}
|
||||
|
||||
y[0] = sumf + x[0] * D[0];
|
||||
y[0] = sumf;
|
||||
|
||||
// recurse
|
||||
s0 = s;
|
||||
|
@ -890,7 +887,6 @@ kernel void kernel_ssm_scan_f32_group(
|
|||
device const void * src4,
|
||||
device const void * src5,
|
||||
device const void * src6,
|
||||
device const void * src7,
|
||||
device float * dst,
|
||||
constant int64_t & d_state,
|
||||
constant int64_t & d_inner,
|
||||
|
@ -923,7 +919,6 @@ kernel void kernel_ssm_scan_f32_group(
|
|||
const uint64_t nb00 = sizeof(float);
|
||||
const uint64_t nb10 = sizeof(float);
|
||||
const uint64_t nb20 = sizeof(float);
|
||||
const uint64_t nb60 = sizeof(float);
|
||||
|
||||
const int64_t nc = d_state;
|
||||
const int64_t nr = d_inner;
|
||||
|
@ -933,7 +928,7 @@ kernel void kernel_ssm_scan_f32_group(
|
|||
|
||||
const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float);
|
||||
|
||||
device const int32_t * ids = (device const int32_t *) src7;
|
||||
device const int32_t * ids = (device const int32_t *) src6;
|
||||
|
||||
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03);
|
||||
device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off);
|
||||
|
@ -944,7 +939,6 @@ kernel void kernel_ssm_scan_f32_group(
|
|||
device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh}
|
||||
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
|
||||
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
|
||||
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
|
||||
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
|
||||
|
||||
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
||||
|
@ -959,7 +953,7 @@ kernel void kernel_ssm_scan_f32_group(
|
|||
s[i] = state;
|
||||
}
|
||||
|
||||
y[0] = sumf + x[0] * D[0];
|
||||
y[0] = sumf;
|
||||
|
||||
// recurse
|
||||
s0 = s;
|
||||
|
|
|
@ -7181,7 +7181,6 @@ struct ggml_tensor * ggml_ssm_conv(
|
|||
const int64_t n_s = sx->ne[2];
|
||||
|
||||
// TODO: maybe support other strides than 1?
|
||||
// FIXME: this is always true?
|
||||
GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
|
||||
GGML_ASSERT(sx->ne[1] == d_inner);
|
||||
GGML_ASSERT(n_t >= 0);
|
||||
|
@ -7205,7 +7204,6 @@ struct ggml_tensor * ggml_ssm_scan(
|
|||
struct ggml_tensor * A,
|
||||
struct ggml_tensor * B,
|
||||
struct ggml_tensor * C,
|
||||
struct ggml_tensor * D,
|
||||
struct ggml_tensor * ids) {
|
||||
GGML_ASSERT(ggml_is_contiguous(s));
|
||||
GGML_ASSERT(ggml_is_contiguous(dt));
|
||||
|
@ -7235,8 +7233,6 @@ struct ggml_tensor * ggml_ssm_scan(
|
|||
GGML_ASSERT(B->ne[0] == d_state);
|
||||
GGML_ASSERT(B->ne[2] == n_seq_tokens);
|
||||
GGML_ASSERT(B->ne[3] == n_seqs);
|
||||
GGML_ASSERT(D->ne[0] == n_head);
|
||||
GGML_ASSERT(ggml_is_vector(D));
|
||||
GGML_ASSERT(ids->ne[0] == n_seqs);
|
||||
GGML_ASSERT(ggml_is_vector(ids));
|
||||
GGML_ASSERT(A->ne[1] == n_head);
|
||||
|
@ -7258,8 +7254,7 @@ struct ggml_tensor * ggml_ssm_scan(
|
|||
result->src[3] = A;
|
||||
result->src[4] = B;
|
||||
result->src[5] = C;
|
||||
result->src[6] = D;
|
||||
result->src[7] = ids;
|
||||
result->src[6] = ids;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
@ -16217,8 +16212,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||
const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
|
||||
const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
|
||||
const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
|
||||
const struct ggml_tensor * src6 = dst->src[6]; // D {n_head}
|
||||
const struct ggml_tensor * src7 = dst->src[7]; // ids {n_seqs}
|
||||
const struct ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
@ -16240,8 +16234,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||
GGML_ASSERT(src3->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src4->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src5->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src6->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src7->nb[0] == sizeof(int32_t));
|
||||
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
|
||||
// allows optimizing the modulo since n_group should be a power of 2
|
||||
GGML_ASSERT((ng & -ng) == ng);
|
||||
|
||||
|
@ -16252,7 +16245,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||
const int ih0 = dh*ith;
|
||||
const int ih1 = MIN(ih0 + dh, nh);
|
||||
|
||||
const int32_t * ids = (const int32_t *) src7->data;
|
||||
const int32_t * ids = (const int32_t *) src6->data;
|
||||
|
||||
for (int i3 = 0; i3 < ns; ++i3) {
|
||||
const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
|
||||
|
@ -16264,7 +16257,6 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||
const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
|
||||
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
|
||||
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
|
||||
const float * D = (const float *) ((const char *) src6->data); // {nh}
|
||||
float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
|
||||
|
||||
if (src3->ne[0] == 1) {
|
||||
|
@ -16325,7 +16317,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||
sumf += state * C[ig];
|
||||
s[i] = state;
|
||||
}
|
||||
y[ii] = sumf + x[ii] * D[h];
|
||||
y[ii] = sumf;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -16353,7 +16345,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||
sumf += state * C[ig];
|
||||
s[i] = state;
|
||||
}
|
||||
y[ii] = sumf + x[ii] * D[h];
|
||||
y[ii] = sumf;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue