mamba : simultaneous sequence processing
A batch can now contain tokens from multiple sequences. This is necessary for at least the parallel example, the server example, and the HellaSwag test in the perplexity example. However, for this to be useful, uses of llama_kv_cache_seq_rm/cp will need to be changed to work on whole sequences. * ggml : add ggml_ssm_conv as a new operator for the conv step of Mamba This operator makes it possible to use and update the correct states for each token of the batch in the same way as ggml_ssm_scan. Other solutions which use existing operators would need loops which would add too many nodes to the graph (at least the ones I thought of). Using this operator further reduces the size of the CPU compute buffer from 140.68 MiB to 103.20 MiB with Mamba 3B with a batch size of 512. And (at least on CPU), it's a bit faster than before. Note that "ggml_ssm_conv" is probably not the most appropriate name, and it could be changed if a better one is found. * llama : add inp_s_seq as a new input tensor The most convenient implementation to select the correct state (for Mamba) for each token is to directly get the correct index from a tensor. This is why inp_s_seq is storing int32_t and not floats. The other, less convenient way to select the correct state would be to have inp_KQ_mask contain 1.0f for each state used by a token and 0.0f otherwise. This complicates quickly fetching the first used state of a token, and is also less efficient because a whole row of the mask would always need to be read for each token. Using indexes makes it easy to stop searching when there are no more sequences for a token, and the first sequence assigned is always very quickly available (it's the first element of each row).
This commit is contained in:
parent
de50c549c4
commit
9473ec2147
3 changed files with 330 additions and 109 deletions
292
ggml.c
292
ggml.c
|
@ -1828,6 +1828,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||||
"FLASH_ATTN",
|
"FLASH_ATTN",
|
||||||
"FLASH_FF",
|
"FLASH_FF",
|
||||||
"FLASH_ATTN_BACK",
|
"FLASH_ATTN_BACK",
|
||||||
|
"SSM_CONV",
|
||||||
"SSM_SCAN",
|
"SSM_SCAN",
|
||||||
"WIN_PART",
|
"WIN_PART",
|
||||||
"WIN_UNPART",
|
"WIN_UNPART",
|
||||||
|
@ -1851,7 +1852,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||||
"CROSS_ENTROPY_LOSS_BACK",
|
"CROSS_ENTROPY_LOSS_BACK",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
|
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
|
||||||
|
|
||||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"none",
|
"none",
|
||||||
|
@ -1915,6 +1916,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"flash_attn(x)",
|
"flash_attn(x)",
|
||||||
"flash_ff(x)",
|
"flash_ff(x)",
|
||||||
"flash_attn_back(x)",
|
"flash_attn_back(x)",
|
||||||
|
"ssm_conv(x)",
|
||||||
"ssm_scan(x)",
|
"ssm_scan(x)",
|
||||||
"win_part(x)",
|
"win_part(x)",
|
||||||
"win_unpart(x)",
|
"win_unpart(x)",
|
||||||
|
@ -1938,7 +1940,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"cross_entropy_loss_back(x,y)",
|
"cross_entropy_loss_back(x,y)",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
|
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
|
||||||
|
|
||||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||||
|
|
||||||
|
@ -6079,6 +6081,51 @@ struct ggml_tensor * ggml_flash_attn_back(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_ssm_conv
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_ssm_conv(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * s,
|
||||||
|
struct ggml_tensor * x,
|
||||||
|
struct ggml_tensor * c,
|
||||||
|
struct ggml_tensor * sq) {
|
||||||
|
GGML_ASSERT(ggml_is_3d(s));
|
||||||
|
GGML_ASSERT(ggml_is_matrix(x));
|
||||||
|
GGML_ASSERT(ggml_is_matrix(c));
|
||||||
|
GGML_ASSERT(ggml_is_matrix(sq));
|
||||||
|
GGML_ASSERT(sq->type == GGML_TYPE_I32);
|
||||||
|
|
||||||
|
const int64_t d_conv = c->ne[0];
|
||||||
|
const int64_t d_inner = c->ne[1];
|
||||||
|
const int64_t n_tokens = x->ne[1];
|
||||||
|
const int64_t n_kv = s->ne[2];
|
||||||
|
|
||||||
|
GGML_ASSERT( s->ne[0] == d_conv - 1);
|
||||||
|
GGML_ASSERT( s->ne[1] == d_inner);
|
||||||
|
GGML_ASSERT( x->ne[0] == d_inner);
|
||||||
|
GGML_ASSERT(sq->ne[0] == n_kv);
|
||||||
|
GGML_ASSERT(sq->ne[1] == n_tokens);
|
||||||
|
|
||||||
|
bool is_node = false;
|
||||||
|
|
||||||
|
if (s->grad || x->grad || c->grad || sq->grad) {
|
||||||
|
GGML_ASSERT(false); // TODO: implement
|
||||||
|
is_node = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
|
||||||
|
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
|
||||||
|
|
||||||
|
result->op = GGML_OP_SSM_CONV;
|
||||||
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
|
result->src[0] = s;
|
||||||
|
result->src[1] = x;
|
||||||
|
result->src[2] = c;
|
||||||
|
result->src[3] = sq;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_ssm_scan
|
// ggml_ssm_scan
|
||||||
|
|
||||||
struct ggml_tensor * ggml_ssm_scan(
|
struct ggml_tensor * ggml_ssm_scan(
|
||||||
|
@ -6088,11 +6135,13 @@ struct ggml_tensor * ggml_ssm_scan(
|
||||||
struct ggml_tensor * dt,
|
struct ggml_tensor * dt,
|
||||||
struct ggml_tensor * A,
|
struct ggml_tensor * A,
|
||||||
struct ggml_tensor * B,
|
struct ggml_tensor * B,
|
||||||
struct ggml_tensor * C) {
|
struct ggml_tensor * C,
|
||||||
|
struct ggml_tensor * sq) {
|
||||||
GGML_ASSERT(ggml_is_contiguous(s));
|
GGML_ASSERT(ggml_is_contiguous(s));
|
||||||
GGML_ASSERT(ggml_is_contiguous(x));
|
GGML_ASSERT(ggml_is_contiguous(x));
|
||||||
GGML_ASSERT(ggml_is_contiguous(dt));
|
GGML_ASSERT(ggml_is_contiguous(dt));
|
||||||
GGML_ASSERT(ggml_is_contiguous(A));
|
GGML_ASSERT(ggml_is_contiguous(A));
|
||||||
|
GGML_ASSERT(sq->type == GGML_TYPE_I32);
|
||||||
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
|
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
|
||||||
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
|
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
|
||||||
GGML_ASSERT(ggml_are_same_shape(x, dt));
|
GGML_ASSERT(ggml_are_same_shape(x, dt));
|
||||||
|
@ -6113,7 +6162,7 @@ struct ggml_tensor * ggml_ssm_scan(
|
||||||
|
|
||||||
bool is_node = false;
|
bool is_node = false;
|
||||||
|
|
||||||
if (s->grad || x->grad || dt->grad || A->grad || B->grad) {
|
if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) {
|
||||||
GGML_ASSERT(false); // TODO: implement
|
GGML_ASSERT(false); // TODO: implement
|
||||||
is_node = true;
|
is_node = true;
|
||||||
}
|
}
|
||||||
|
@ -6129,6 +6178,7 @@ struct ggml_tensor * ggml_ssm_scan(
|
||||||
result->src[3] = A;
|
result->src[3] = A;
|
||||||
result->src[4] = B;
|
result->src[4] = B;
|
||||||
result->src[5] = C;
|
result->src[5] = C;
|
||||||
|
result->src[6] = sq;
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -14646,6 +14696,135 @@ static void ggml_compute_forward_flash_attn_back(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_compute_forward_ssm_conv
|
||||||
|
|
||||||
|
static void ggml_compute_forward_ssm_conv_f32(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
const struct ggml_tensor * src0, // conv_state
|
||||||
|
const struct ggml_tensor * src1, // x
|
||||||
|
const struct ggml_tensor * src2, // conv1d.weight
|
||||||
|
const struct ggml_tensor * src3, // state_seq
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const int nc = src2->ne[0]; // d_conv
|
||||||
|
const int nr = src0->ne[1]; // d_inner
|
||||||
|
const int n_t = src1->ne[1]; // n_tokens
|
||||||
|
const int n_kv = src0->ne[2]; // max number of sequences in the batch
|
||||||
|
|
||||||
|
GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
|
||||||
|
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||||
|
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||||
|
GGML_ASSERT(src2->nb[0] == sizeof(float));
|
||||||
|
GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
|
||||||
|
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
||||||
|
// for use with the destination state offset between sequences
|
||||||
|
GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
|
||||||
|
|
||||||
|
// rows per thread
|
||||||
|
const int dr = (nr + nth - 1)/nth;
|
||||||
|
|
||||||
|
// row range for this thread
|
||||||
|
const int ir0 = dr*ith;
|
||||||
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
const int ir = ir1 - ir0;
|
||||||
|
|
||||||
|
if (n_kv > 1) {
|
||||||
|
// multiple sequences means it's hard to know when it's the first time a state is read,
|
||||||
|
// so copy them all over to the destination, just to be sure.
|
||||||
|
for (int i3 = 0; i3 < n_kv; ++i3) {
|
||||||
|
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
|
||||||
|
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
|
||||||
|
// can't use memcpy because of d_conv vs d_conv - 1
|
||||||
|
for (int i1 = 0; i1 < ir; ++i1) {
|
||||||
|
for (int i0 = 0; i0 < nc - 1; ++i0) {
|
||||||
|
// copy s0 to last (d_conv - 1) columns of s
|
||||||
|
s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i2 = 0; i2 < n_t; ++i2) {
|
||||||
|
int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens}
|
||||||
|
float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
|
||||||
|
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
|
||||||
|
float * s0; // {d_conv - 1, d_inner, n_kv}
|
||||||
|
float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
||||||
|
float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
|
||||||
|
int ne0s0;
|
||||||
|
|
||||||
|
GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
|
||||||
|
|
||||||
|
// avoid needing to copy the state for the first token
|
||||||
|
if (i2 == 0) {
|
||||||
|
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
|
||||||
|
ne0s0 = src0->ne[0];
|
||||||
|
} else {
|
||||||
|
// the source is the last (d_conv - 1) columns of the destination
|
||||||
|
s0 = s + 1;
|
||||||
|
ne0s0 = nc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// d_inner
|
||||||
|
for (int i1 = 0; i1 < ir; ++i1) {
|
||||||
|
// shift state left
|
||||||
|
for (int i0 = 0; i0 < nc - 1; ++i0) {
|
||||||
|
s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
|
||||||
|
}
|
||||||
|
// insert x on the last column
|
||||||
|
s[(nc - 1) + i1*nc] = x0[i1];
|
||||||
|
}
|
||||||
|
|
||||||
|
// handle copies when there are multiple output states
|
||||||
|
for (int i3 = 1; i3 < n_kv; ++i3) {
|
||||||
|
int32_t seq = sq[i3];
|
||||||
|
if (0 <= seq && seq < n_kv) {
|
||||||
|
float * s1 = s + (seq - sq[0])*nc*nr;
|
||||||
|
memcpy(s1, s, nc*ir*sizeof(float));
|
||||||
|
} else {
|
||||||
|
// stop at negative or too big seq_ids
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// it seems a little faster when this is separate from the state shift
|
||||||
|
for (int i1 = 0; i1 < ir; ++i1) {
|
||||||
|
// rowwise dot product
|
||||||
|
float sumf = 0.0f;
|
||||||
|
for (int i0 = 0; i0 < nc; ++i0) {
|
||||||
|
int i = i0 + i1*nc;
|
||||||
|
sumf += s[i] * c[i];
|
||||||
|
}
|
||||||
|
x[i1] = sumf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_ssm_conv(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
const struct ggml_tensor * src0,
|
||||||
|
const struct ggml_tensor * src1,
|
||||||
|
const struct ggml_tensor * src2,
|
||||||
|
const struct ggml_tensor * src3,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_ssm_conv_f32(params, src0, src1, src2, src3, dst);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
} break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_ssm_scan
|
// ggml_compute_forward_ssm_scan
|
||||||
|
|
||||||
static void ggml_compute_forward_ssm_scan_f32(
|
static void ggml_compute_forward_ssm_scan_f32(
|
||||||
|
@ -14656,6 +14835,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||||
const struct ggml_tensor * src3, // A
|
const struct ggml_tensor * src3, // A
|
||||||
const struct ggml_tensor * src4, // B
|
const struct ggml_tensor * src4, // B
|
||||||
const struct ggml_tensor * src5, // C
|
const struct ggml_tensor * src5, // C
|
||||||
|
const struct ggml_tensor * src6, // sq
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
||||||
return;
|
return;
|
||||||
|
@ -14664,9 +14844,10 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||||
const int ith = params->ith;
|
const int ith = params->ith;
|
||||||
const int nth = params->nth;
|
const int nth = params->nth;
|
||||||
|
|
||||||
const int64_t nc = src0->ne[0]; // d_state
|
const int64_t nc = src0->ne[0]; // d_state
|
||||||
const int64_t nr = src0->ne[1]; // d_inner
|
const int64_t nr = src0->ne[1]; // d_inner
|
||||||
const int64_t n_t = src1->ne[1]; // number of tokens in the batch
|
const int64_t n_t = src1->ne[1]; // number of tokens in the batch
|
||||||
|
const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
|
||||||
|
|
||||||
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
|
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
|
||||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||||
|
@ -14675,9 +14856,11 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||||
GGML_ASSERT(src3->nb[0] == sizeof(float));
|
GGML_ASSERT(src3->nb[0] == sizeof(float));
|
||||||
GGML_ASSERT(src4->nb[0] == sizeof(float));
|
GGML_ASSERT(src4->nb[0] == sizeof(float));
|
||||||
GGML_ASSERT(src5->nb[0] == sizeof(float));
|
GGML_ASSERT(src5->nb[0] == sizeof(float));
|
||||||
// required for the dot product between s and C
|
// required for the dot product between s and C, and when copying the states
|
||||||
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
||||||
// required to get correct offset for state destination
|
// required for per-sequence offsets for states
|
||||||
|
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
|
||||||
|
// required to get correct offset for state destination (i.e. src1->nb[2])
|
||||||
GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
|
GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
|
||||||
|
|
||||||
// rows per thread
|
// rows per thread
|
||||||
|
@ -14688,16 +14871,37 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||||
const int ir1 = MIN(ir0 + dr, nr);
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
const int ir = ir1 - ir0;
|
const int ir = ir1 - ir0;
|
||||||
|
|
||||||
// first token in the batch
|
if (n_kv > 1) {
|
||||||
{
|
// it's hard to know if the source states have already been copied
|
||||||
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
|
// when there are multiple, so copy them already.
|
||||||
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + src1->nb[2]); // {d_state, d_inner, n_kv}
|
for (int i3 = 0; i3 < n_kv; ++i3) {
|
||||||
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner, n_kv}
|
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
|
||||||
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
|
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
|
||||||
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens}
|
memcpy(s, s0, nc*ir*sizeof(float));
|
||||||
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
}
|
||||||
float * B = (float *) ((char *) src4->data); // {d_state, n_tokens}
|
}
|
||||||
float * C = (float *) ((char *) src5->data); // {d_state, n_tokens}
|
|
||||||
|
for (int i2 = 0; i2 < n_t; ++i2) {
|
||||||
|
int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens}
|
||||||
|
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
||||||
|
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
|
||||||
|
float * s0;
|
||||||
|
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
||||||
|
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
|
||||||
|
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
||||||
|
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
|
||||||
|
float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
|
||||||
|
|
||||||
|
GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
|
||||||
|
|
||||||
|
// avoid needing to copy the state for the first token
|
||||||
|
if (i2 == 0) {
|
||||||
|
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
|
||||||
|
} else {
|
||||||
|
// otherwise the source is the same as the destination
|
||||||
|
s0 = s;
|
||||||
|
}
|
||||||
|
|
||||||
// d_inner
|
// d_inner
|
||||||
for (int i1 = 0; i1 < ir; ++i1) {
|
for (int i1 = 0; i1 < ir; ++i1) {
|
||||||
float dt_soft_plus = log1pf(expf(dt[i1]));
|
float dt_soft_plus = log1pf(expf(dt[i1]));
|
||||||
|
@ -14707,41 +14911,24 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||||
for (int i0 = 0; i0 < nc; ++i0) {
|
for (int i0 = 0; i0 < nc; ++i0) {
|
||||||
int i = i0 + i1*nc;
|
int i = i0 + i1*nc;
|
||||||
// state = prev_state * dA + dB * x
|
// state = prev_state * dA + dB * x
|
||||||
float state = s0[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
||||||
// y = rowwise_dotprod(state, C)
|
// y = rowwise_dotprod(state, C)
|
||||||
sumf += state*C[i0];
|
sumf += state * C[i0];
|
||||||
// FIXME: handle simultaneous sequences
|
|
||||||
s[i] = state;
|
s[i] = state;
|
||||||
}
|
}
|
||||||
y[i1] = sumf;
|
y[i1] = sumf;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// rest of the batch, state comes from previous one which was stored in destination
|
// handle copies when there are multiple output states
|
||||||
for (int i2 = 1; i2 < n_t; ++i2) {
|
for (int i3 = 1; i3 < n_kv; ++i3) {
|
||||||
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
int32_t seq = sq[i3];
|
||||||
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + src1->nb[2]); // {d_state, d_inner, n_kv}
|
if (0 <= seq && seq < n_kv) {
|
||||||
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
float * s1 = s + (seq - sq[0])*nc*nr;
|
||||||
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
|
memcpy(s1, s, nc*ir*sizeof(float));
|
||||||
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
} else {
|
||||||
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
|
// stop at negative or too big seq_ids
|
||||||
float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
|
break;
|
||||||
// d_inner
|
|
||||||
for (int i1 = 0; i1 < ir; ++i1) {
|
|
||||||
float dt_soft_plus = log1pf(expf(dt[i1]));
|
|
||||||
float x_dt = x[i1] * dt_soft_plus;
|
|
||||||
float sumf = 0.0f;
|
|
||||||
// d_state
|
|
||||||
for (int i0 = 0; i0 < nc; ++i0) {
|
|
||||||
int i = i0 + i1*nc;
|
|
||||||
// state = prev_state * dA + dB * x
|
|
||||||
float state = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
|
||||||
// y = rowwise_dotprod(state, C)
|
|
||||||
sumf += state*C[i0];
|
|
||||||
// FIXME: handle simultaneous sequences
|
|
||||||
s[i] = state;
|
|
||||||
}
|
}
|
||||||
y[i1] = sumf;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -14754,11 +14941,12 @@ static void ggml_compute_forward_ssm_scan(
|
||||||
const struct ggml_tensor * src3,
|
const struct ggml_tensor * src3,
|
||||||
const struct ggml_tensor * src4,
|
const struct ggml_tensor * src4,
|
||||||
const struct ggml_tensor * src5,
|
const struct ggml_tensor * src5,
|
||||||
|
const struct ggml_tensor * src6,
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, src5, dst);
|
ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, src5, src6, dst);
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
|
@ -15818,9 +16006,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||||
bool masked = t != 0;
|
bool masked = t != 0;
|
||||||
ggml_compute_forward_flash_attn_back(params, masked, tensor);
|
ggml_compute_forward_flash_attn_back(params, masked, tensor);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_SSM_CONV:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_ssm_conv(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
|
||||||
|
} break;
|
||||||
case GGML_OP_SSM_SCAN:
|
case GGML_OP_SSM_SCAN:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5], tensor);
|
ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5], tensor->src[6], tensor);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_WIN_PART:
|
case GGML_OP_WIN_PART:
|
||||||
{
|
{
|
||||||
|
@ -16868,6 +17060,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
{
|
{
|
||||||
GGML_ASSERT(false); // not supported
|
GGML_ASSERT(false); // not supported
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_SSM_CONV:
|
||||||
case GGML_OP_SSM_SCAN:
|
case GGML_OP_SSM_SCAN:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(false); // TODO: not implemented
|
GGML_ASSERT(false); // TODO: not implemented
|
||||||
|
@ -17569,6 +17762,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_SSM_CONV:
|
||||||
case GGML_OP_SSM_SCAN:
|
case GGML_OP_SSM_SCAN:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
|
|
11
ggml.h
11
ggml.h
|
@ -460,6 +460,7 @@ extern "C" {
|
||||||
GGML_OP_FLASH_ATTN,
|
GGML_OP_FLASH_ATTN,
|
||||||
GGML_OP_FLASH_FF,
|
GGML_OP_FLASH_FF,
|
||||||
GGML_OP_FLASH_ATTN_BACK,
|
GGML_OP_FLASH_ATTN_BACK,
|
||||||
|
GGML_OP_SSM_CONV,
|
||||||
GGML_OP_SSM_SCAN,
|
GGML_OP_SSM_SCAN,
|
||||||
GGML_OP_WIN_PART,
|
GGML_OP_WIN_PART,
|
||||||
GGML_OP_WIN_UNPART,
|
GGML_OP_WIN_UNPART,
|
||||||
|
@ -1702,6 +1703,13 @@ extern "C" {
|
||||||
struct ggml_tensor * c0,
|
struct ggml_tensor * c0,
|
||||||
struct ggml_tensor * c1);
|
struct ggml_tensor * c1);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_ssm_conv(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * s,
|
||||||
|
struct ggml_tensor * x,
|
||||||
|
struct ggml_tensor * c,
|
||||||
|
struct ggml_tensor * sq);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_ssm_scan(
|
GGML_API struct ggml_tensor * ggml_ssm_scan(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * s,
|
struct ggml_tensor * s,
|
||||||
|
@ -1709,7 +1717,8 @@ extern "C" {
|
||||||
struct ggml_tensor * dt,
|
struct ggml_tensor * dt,
|
||||||
struct ggml_tensor * A,
|
struct ggml_tensor * A,
|
||||||
struct ggml_tensor * B,
|
struct ggml_tensor * B,
|
||||||
struct ggml_tensor * C);
|
struct ggml_tensor * C,
|
||||||
|
struct ggml_tensor * sq);
|
||||||
|
|
||||||
// partition into non-overlapping windows with padding if needed
|
// partition into non-overlapping windows with padding if needed
|
||||||
// example:
|
// example:
|
||||||
|
|
136
llama.cpp
136
llama.cpp
|
@ -2044,6 +2044,7 @@ struct llama_context {
|
||||||
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
|
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
|
||||||
struct ggml_tensor * inp_cls; // I32 [n_batch]
|
struct ggml_tensor * inp_cls; // I32 [n_batch]
|
||||||
struct ggml_tensor * inp_s_mask; // F32 [kv_size] (only used by constant state models like Mamba)
|
struct ggml_tensor * inp_s_mask; // F32 [kv_size] (only used by constant state models like Mamba)
|
||||||
|
struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]
|
||||||
|
|
||||||
#ifdef GGML_USE_MPI
|
#ifdef GGML_USE_MPI
|
||||||
ggml_mpi_context * ctx_mpi = NULL;
|
ggml_mpi_context * ctx_mpi = NULL;
|
||||||
|
@ -4761,8 +4762,9 @@ static bool llm_load_tensors(
|
||||||
const int64_t d_conv = hparams.n_embd_head_k + 1;
|
const int64_t d_conv = hparams.n_embd_head_k + 1;
|
||||||
const int64_t d_state = hparams.n_embd_head_v;
|
const int64_t d_state = hparams.n_embd_head_v;
|
||||||
const int64_t d_inner = hparams.n_head;
|
const int64_t d_inner = hparams.n_head;
|
||||||
// FIXME: ceiling instead of floor
|
// TODO: allow loading dt_rank from the model config
|
||||||
const int64_t dt_rank = n_embd / 16;
|
// ceiling division
|
||||||
|
const int64_t dt_rank = (n_embd / 16) + (n_embd % 16 > 0);
|
||||||
GGML_ASSERT(2 * n_embd == d_inner);
|
GGML_ASSERT(2 * n_embd == d_inner);
|
||||||
|
|
||||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||||
|
@ -8012,13 +8014,12 @@ struct llm_build_context {
|
||||||
GGML_ASSERT(2 * d_model == d_inner);
|
GGML_ASSERT(2 * d_model == d_inner);
|
||||||
const int64_t d_conv = n_embd_head_k + 1;
|
const int64_t d_conv = n_embd_head_k + 1;
|
||||||
const int64_t d_state = n_embd_head_v;
|
const int64_t d_state = n_embd_head_v;
|
||||||
const int64_t dt_rank = d_model / 16;
|
// ceiling division
|
||||||
|
const int64_t dt_rank = (d_model / 16) + (d_model % 16 > 0);
|
||||||
|
|
||||||
struct ggml_tensor * cur;
|
struct ggml_tensor * cur;
|
||||||
struct ggml_tensor * inpL;
|
struct ggml_tensor * inpL;
|
||||||
|
|
||||||
GGML_ASSERT(kv_self.used - kv_self.head + 1 == 1); // TODO: support more than one sequence per batch
|
|
||||||
|
|
||||||
// {n_embd, n_tokens}
|
// {n_embd, n_tokens}
|
||||||
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
|
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
|
||||||
cb(inpL, "inp_embd", -1);
|
cb(inpL, "inp_embd", -1);
|
||||||
|
@ -8040,8 +8041,8 @@ struct llm_build_context {
|
||||||
state_mask);
|
state_mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * conv_state = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
|
conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
|
||||||
struct ggml_tensor * ssm_state = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv);
|
ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv);
|
||||||
|
|
||||||
// norm
|
// norm
|
||||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||||
|
@ -8056,37 +8057,31 @@ struct llm_build_context {
|
||||||
struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
|
struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
|
||||||
struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
|
struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
|
||||||
|
|
||||||
|
struct ggml_tensor * state_seq = ggml_view_2d(ctx0, lctx.inp_s_seq, n_kv, n_tokens, n_kv*ggml_element_size(lctx.inp_s_seq), 0);
|
||||||
|
|
||||||
// conv
|
// conv
|
||||||
{
|
{
|
||||||
// concat last (d_conv - 1) columns of conv_state, and x
|
// Custom operator which is needed only to ease simultaneous sequence processing.
|
||||||
|
// For a single sequence, the equivalent is to concatenate the columns of conv_states and x,
|
||||||
|
// then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension,
|
||||||
|
// then element-wise multiply that with the conv1d weigth,
|
||||||
|
// then sum the elements of each row,
|
||||||
|
// (the last two steps are a dot product over rows (also doable with mul_mat))
|
||||||
|
// then permute away the ne[0] dimension,
|
||||||
|
// and then you're left with the resulting x tensor.
|
||||||
|
// The new conv_states is the last (d_conv - 1) columns
|
||||||
|
// of the last 3rd dimensional "layer" of the self-overlapping view.
|
||||||
|
// For simultaneous sequences, it's more complicated.
|
||||||
|
struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq);
|
||||||
|
|
||||||
// The following tensor is too big in order to avoid an assertion error when making an overlapping view.
|
// store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache
|
||||||
// TODO: in ggml_new_tensor_impl, handle overlapping data range in data size calculation
|
|
||||||
// This could then be a tensor with ne[] = {(d_conv-1)+n_tokens, d_inner},
|
|
||||||
// but the size difference is not that big (d_conv is usually 4).
|
|
||||||
struct ggml_tensor * conv_x = ggml_new_tensor_1d(ctx0, conv_state->type, d_conv*d_inner*n_tokens);
|
|
||||||
const size_t conv_x_nb1 = (d_conv - 1 + n_tokens) * ggml_element_size(conv_x);
|
|
||||||
|
|
||||||
conv_x = ggml_set_2d(ctx0, conv_x, conv_state, conv_x_nb1, 0);
|
|
||||||
// making x contiguous is necessary because ggml_set expects it
|
|
||||||
conv_x = ggml_set_2d(ctx0, conv_x, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_x_nb1, (d_conv - 1)*ggml_element_size(conv_x));
|
|
||||||
|
|
||||||
// store last (d_conv - 1) columns of conv_x back into the KV cache for the next conv_state
|
|
||||||
ggml_build_forward_expand(gf,
|
ggml_build_forward_expand(gf,
|
||||||
ggml_cpy(ctx0,
|
ggml_cpy(ctx0,
|
||||||
ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tokens*ggml_element_size(conv_x)),
|
ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
|
||||||
ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_x))));
|
ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
|
||||||
|
|
||||||
// prepare convolution for all tokens in the batch with a self-overlapping view,
|
// extract x from x_conv
|
||||||
// shifting by one column each ... depth? ... with a window of d_conv columns.
|
x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0);
|
||||||
// {(d_conv-1)+n_tokens, d_inner} => {d_conv, d_inner, n_tokens}
|
|
||||||
conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tokens, conv_x_nb1, 1*ggml_element_size(conv_x), 0);
|
|
||||||
|
|
||||||
// perform convolution
|
|
||||||
// => {1, d_inner, n_tokens}
|
|
||||||
x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_x, model.layers[il].ssm_conv1d));
|
|
||||||
// => {d_inner, n_tokens, 1}
|
|
||||||
x = ggml_permute(ctx0, x, 2, 0, 1, 3);
|
|
||||||
|
|
||||||
// bias
|
// bias
|
||||||
x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
|
x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
|
||||||
|
@ -8111,13 +8106,13 @@ struct llm_build_context {
|
||||||
// as described in the Annex D of the Mamba paper.
|
// as described in the Annex D of the Mamba paper.
|
||||||
// => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined,
|
// => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined,
|
||||||
// because only a single tensor can be returned.
|
// because only a single tensor can be returned.
|
||||||
struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_state, x, dt, model.layers[il].ssm_a, B, C);
|
struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq);
|
||||||
|
|
||||||
// store last states (the second part of y_ssm_states)
|
// store last states (the second part of y_ssm_states)
|
||||||
ggml_build_forward_expand(gf,
|
ggml_build_forward_expand(gf,
|
||||||
ggml_cpy(ctx0,
|
ggml_cpy(ctx0,
|
||||||
ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
|
ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
|
||||||
ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_state))));
|
ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_states))));
|
||||||
|
|
||||||
struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
|
struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
|
||||||
|
|
||||||
|
@ -8362,7 +8357,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||||
|
|
||||||
float * data = (float *) lctx.inp_KQ_mask->data;
|
float * data = (float *) lctx.inp_KQ_mask->data;
|
||||||
|
|
||||||
// For Transformers, use only the previous KV cells
|
// For Transformers, use only the previous KV cells (or all, when non-causal)
|
||||||
// of the correct sequence for each token of the batch.
|
// of the correct sequence for each token of the batch.
|
||||||
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
||||||
for (int h = 0; h < 1; ++h) {
|
for (int h = 0; h < 1; ++h) {
|
||||||
|
@ -8382,13 +8377,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// For Mamba (and other constant-time-and-size architectures),
|
|
||||||
// update the correct state(s)/sequence(s) for each token of the batch.
|
|
||||||
// Source and destination states are both the same for the sake of implementation simplicity.
|
|
||||||
// It would be more complex if they were sometimes the same and somtimes not.
|
|
||||||
// (with Transformers, source KV cells are never the destination,
|
|
||||||
// which is also simpler, but more memory hungry)
|
|
||||||
// TODO: implement
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (hparams.need_kq_pos) {
|
if (hparams.need_kq_pos) {
|
||||||
|
@ -8447,28 +8435,54 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (kv_self.unlimited) {
|
if (kv_self.unlimited) {
|
||||||
const uint32_t kv_size = kv_self.size;
|
const int64_t kv_size = kv_self.size;
|
||||||
const uint32_t n_kv = kv_self.n;
|
const int64_t n_kv = kv_self.n;
|
||||||
|
|
||||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
|
{
|
||||||
float * data = (float *) lctx.inp_s_mask->data;
|
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
|
||||||
|
float * data = (float *) lctx.inp_s_mask->data;
|
||||||
|
|
||||||
// states which are not affected by the current batch are left untouched
|
// states which are not affected by the current batch are left untouched
|
||||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
llama_seq_id seq_id = i + lctx.kv_self.head;
|
llama_seq_id seq_id = i + lctx.kv_self.head;
|
||||||
llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
|
llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
|
||||||
bool has_self_seq = kv_cell.has_seq_id(seq_id);
|
bool has_self_seq = kv_cell.has_seq_id(seq_id);
|
||||||
|
|
||||||
data[i] = (float) has_self_seq;
|
data[i] = (float) has_self_seq;
|
||||||
|
|
||||||
// ensure current sequences will be kept
|
// ensure current sequences will be kept
|
||||||
if (!has_self_seq) {
|
if (!has_self_seq) {
|
||||||
kv_cell.seq_id.insert(seq_id);
|
kv_cell.seq_id.insert(seq_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// For Mamba (and other constant-time-and-size architectures),
|
||||||
|
// update the correct state(s)/sequence(s) for each token of the batch.
|
||||||
|
// Like with the KQ_mask, if a token in the batch has multiple sequences,
|
||||||
|
// they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv).
|
||||||
|
{
|
||||||
|
const int64_t n_tokens = batch.n_tokens;
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer));
|
||||||
|
int32_t * data = (int32_t *) lctx.inp_s_seq->data;
|
||||||
|
|
||||||
|
for (int j = 0; j < n_tokens; ++j) {
|
||||||
|
const int32_t n_seq = batch.n_seq_id[j];
|
||||||
|
GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence
|
||||||
|
|
||||||
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
|
if (i < n_seq) {
|
||||||
|
// for this type of model, the head is the minimum seq_id of the batch
|
||||||
|
data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head;
|
||||||
|
} else {
|
||||||
|
data[j*n_kv + i] = -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// remove extraneous seq_ids when state copies are made
|
// remove extraneous seq_ids when state copies are made
|
||||||
{
|
{
|
||||||
for (uint32_t i = 0; i < kv_size; ++i) {
|
for (int i = 0; i < kv_size; ++i) {
|
||||||
llama_kv_cell & kv_cell = lctx.kv_self.cells[i];
|
llama_kv_cell & kv_cell = lctx.kv_self.cells[i];
|
||||||
uint32_t n_seqs = kv_cell.seq_id.size();
|
uint32_t n_seqs = kv_cell.seq_id.size();
|
||||||
bool has_self_seq = kv_cell.has_seq_id(i);
|
bool has_self_seq = kv_cell.has_seq_id(i);
|
||||||
|
@ -12642,7 +12656,7 @@ struct llama_context * llama_new_context_with_model(
|
||||||
// graph inputs
|
// graph inputs
|
||||||
{
|
{
|
||||||
ggml_init_params init_params = {
|
ggml_init_params init_params = {
|
||||||
/* .mem_size */ ggml_tensor_overhead()*(8 + ctx->kv_self.unlimited),
|
/* .mem_size */ ggml_tensor_overhead()*(8 + 2*(ctx->kv_self.unlimited)),
|
||||||
/* .mem_buffer */ nullptr,
|
/* .mem_buffer */ nullptr,
|
||||||
/* .no_alloc */ true,
|
/* .no_alloc */ true,
|
||||||
};
|
};
|
||||||
|
@ -12656,8 +12670,10 @@ struct llama_context * llama_new_context_with_model(
|
||||||
ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
|
ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
|
||||||
ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
|
ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
|
||||||
ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
|
ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
|
||||||
if (ctx->kv_self.unlimited)
|
if (ctx->kv_self.unlimited) {
|
||||||
ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
|
ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
|
||||||
|
ctx->inp_s_seq = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);
|
||||||
|
}
|
||||||
|
|
||||||
ggml_set_name(ctx->inp_tokens, "inp_tokens");
|
ggml_set_name(ctx->inp_tokens, "inp_tokens");
|
||||||
ggml_set_name(ctx->inp_embd, "inp_embd");
|
ggml_set_name(ctx->inp_embd, "inp_embd");
|
||||||
|
@ -12667,8 +12683,10 @@ struct llama_context * llama_new_context_with_model(
|
||||||
ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
|
ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
|
||||||
ggml_set_name(ctx->inp_mean, "inp_mean");
|
ggml_set_name(ctx->inp_mean, "inp_mean");
|
||||||
ggml_set_name(ctx->inp_cls, "inp_cls");
|
ggml_set_name(ctx->inp_cls, "inp_cls");
|
||||||
if (ctx->kv_self.unlimited)
|
if (ctx->kv_self.unlimited) {
|
||||||
ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
|
ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
|
||||||
|
ggml_set_name(ctx->inp_s_seq, "inp_s_seq");
|
||||||
|
}
|
||||||
|
|
||||||
ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
|
ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
|
||||||
LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__,
|
LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue