mamba : multiple sequences, but one at a time

This is a step towards making this Mamba implementation usable
with the server example (the way the system prompt is kept when clearing
the client slots will need to be changed before this can work, though).

The KV cache size for this kind of model is tied to the maximum number
of sequences kept at any single time.
For now, this number is obtained from n_parallel (plus one,
to have an extra sequence to dedicate to the system prompt),
but there might be a better way to do this which won't also
make the main example use 2 cells even if only 1 is really used.
(for this specific case, --parallel 0 helps)

Simultaneous sequence processing will probably require changes to
ggml_ssm_scan, and possibly a new operator for the conv step.

* mamba : support llama_kv_cache_seq_cp

This (mis)uses the logic around K shifts, because tokens in a state
can't be shifted anyway, and because inp_K_shift has the right shape and type.
Using ggml_get_rows is a nice way to do copies, but copy chains can't work.
Fortunately, copy chains don't really seem to be used in the examples.

Each KV cell is dedicated to the sequence ID corresponding to its own index.

* mamba : use a state mask

It's cleaner than the previous heuristic of
checking for the pos of the first token in the batch.

inp_KQ_mask could not be re-used for this, because it has the wrong shape
and because it seems more suited to the next step of
simultaneous sequence processing (helping with the problem of
remembering which token belongs to which sequence(s)/state(s)).

* llama : replace the usage of n_ctx with kv_self.size in many places

* mamba : use n_tokens directly instead of n_tok
This commit is contained in:
Francis Couture-Harpin 2024-02-13 19:06:18 -05:00
parent 6ff34da092
commit 8a43ffcfa1
4 changed files with 256 additions and 84 deletions

View file

@ -1295,6 +1295,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.n_ctx = params.n_ctx;
cparams.n_batch = params.n_batch;
cparams.n_parallel = params.n_parallel;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.seed = params.seed;

22
ggml.c
View file

@ -6099,13 +6099,13 @@ struct ggml_tensor * ggml_ssm_scan(
{
const int64_t d_state = s->ne[0];
const int64_t d_inner = s->ne[1];
const int64_t n_tok = x->ne[1];
const int64_t n_tokens = x->ne[1];
GGML_ASSERT(x->ne[0] == d_inner);
GGML_ASSERT(A->ne[0] == d_state);
GGML_ASSERT(A->ne[1] == d_inner);
GGML_ASSERT(B->ne[0] == d_state);
GGML_ASSERT(B->ne[1] == n_tok);
GGML_ASSERT(B->ne[1] == n_tokens);
}
bool is_node = false;
@ -14682,12 +14682,12 @@ static void ggml_compute_forward_ssm_scan_f32(
// first batch
{
float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tokens}
float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens}
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
float * B = (float *) ((char *) src4->data); // {d_state, n_tok}
float * B = (float *) ((char *) src4->data); // {d_state, n_tokens}
// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
float dt_soft_plus = log1pf(expf(dt[i1]));
@ -14703,12 +14703,12 @@ static void ggml_compute_forward_ssm_scan_f32(
// compute state for rest of tokens, previous state comes from dest
for (int i2 = 1; i2 < n_t; ++i2) {
float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok}
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok}
float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tokens}
float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tokens}
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tokens}
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tokens}
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tok}
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
float dt_soft_plus = log1pf(expf(dt[i1]));

298
llama.cpp
View file

@ -1802,6 +1802,8 @@ struct llama_kv_cell {
struct llama_kv_cache {
bool has_shift = false;
bool do_defrag = false;
// with Mamba, a slot can hold the state for more than one past token
bool unlimited = false;
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_internal also uses it, so it
@ -2036,11 +2038,12 @@ struct llama_context {
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
struct ggml_tensor * inp_KQ_pos; // F32 [n_ctx]
struct ggml_tensor * inp_K_shift; // I32 [n_ctx]
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
struct ggml_tensor * inp_KQ_pos; // F32 [kv_size]
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
struct ggml_tensor * inp_s_mask; // F32 [kv_size] (only used by constant state models like Mamba)
#ifdef GGML_USE_MPI
ggml_mpi_context * ctx_mpi = NULL;
@ -2056,7 +2059,7 @@ static bool llama_kv_cache_init(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
uint32_t n_ctx,
uint32_t kv_size,
bool offload) {
const struct llama_hparams & hparams = model.hparams;
@ -2064,22 +2067,26 @@ static bool llama_kv_cache_init(
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const int64_t n_layer = hparams.n_layer;
if (model.arch == LLM_ARCH_MAMBA) {
// only one slot is needed for Mamba
n_ctx = 1;
}
cache.has_shift = false;
// for now, only Mamba can hold state for more than one past token per slot
cache.unlimited = model.arch == LLM_ARCH_MAMBA;
cache.head = 0;
cache.size = n_ctx;
cache.size = kv_size;
cache.used = 0;
cache.type_k = type_k;
cache.type_v = type_v;
cache.cells.clear();
cache.cells.resize(n_ctx);
cache.cells.resize(kv_size);
if (cache.unlimited) {
for (uint32_t i = 0; i < cache.size; ++i) {
cache.cells[i].delta = i;
}
} // else, delta is already initialized to zero
#ifdef GGML_USE_CLBLAST
offload = false;
@ -2118,8 +2125,8 @@ static bool llama_kv_cache_init(
for (int i = 0; i < (int) n_layer; i++) {
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*n_ctx);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*n_ctx);
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
cache.k_l.push_back(k);
@ -2153,11 +2160,51 @@ static bool llama_kv_cache_find_slot(
const uint32_t n_ctx = cache.size;
const uint32_t n_tokens = batch.n_tokens;
// for Mamba and/or other model archs that only ever use one slot
if (n_ctx == 1) {
// hopefully no one actually uses a context size of 1 on Transformer-based models
return true;
if (cache.unlimited) {
// For unlimited context architectures (like Mamba),
// each KV cache cell can store the state for a whole sequence.
// starting point to find the minimum seq_id used in the batch
cache.head = cache.size - 1;
// likewise, to find the max seq_id in the batch
cache.used = 0;
for (uint32_t i = 0; i < n_tokens; ++i) {
for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
llama_seq_id seq_id = batch.seq_id[i][j];
// make sure it's a valid seq_id
if ((uint32_t)seq_id < cache.size) {
// the number of "used" cells is simply the biggest seq_id
if (cache.used < (uint32_t)seq_id) {
cache.used = seq_id;
}
// the "head" is the smallest seq_id
if (cache.head > (uint32_t)seq_id) {
cache.head = seq_id;
}
// Assuming the tokens are in-order
if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
// What should happen when the pos backtracks?
// Clearing the state mid-batch would require special-casing which isn't done.
LLAMA_LOG_ERROR("%s: non-consecutive token position %d after %d for sequence %d\n",
__func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
return false;
}
cache.cells[seq_id].pos = batch.pos[i];
// NOTE: seq_ids are not inserted here, because they are handled when the graph is built.
} else {
// too big seq_id
// TODO: would it be possible to resize the KV cache size instead?
LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d\n", __func__, seq_id, cache.size);
return false;
}
}
}
cache.n = cache.used - cache.head + 1;
// sanity check (max >= min)
return cache.used >= cache.head;
}
// otherwise, one cell per token.
if (n_tokens > n_ctx) {
LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
@ -2238,6 +2285,13 @@ static void llama_kv_cache_seq_rm(
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
if (cache.unlimited) {
// can only remove whole sequences for models like Mamba
GGML_ASSERT(p0 == 0);
GGML_ASSERT((uint32_t)seq_id < cache.size);
GGML_ASSERT(cache.cells[seq_id].pos < p1);
}
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
if (seq_id < 0) {
@ -2270,6 +2324,26 @@ static void llama_kv_cache_seq_cp(
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
if (cache.unlimited) {
if ((uint32_t)seq_id_dst < cache.size && (uint32_t)seq_id_src < cache.size) {
// intent to "copy from" (does not support copy chains)
cache.cells[seq_id_dst].delta = seq_id_src;
// NOTE: a sequence can't have multiple sources, but can have multiple destinations.
// For compatibility with the other KV cache API functions,
// the seq_id(s) of a slot suggests an intent to "copy to" those id(s),
// so that when a sequence is copied, it can initially be found from the source cell.
cache.cells[seq_id_src].seq_id.insert(seq_id_dst);
// prevent the destination from getting cleared
cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
// repurposed as a "need copy" flag
// (shifting can't be done anyway for this kind of KV cache)
cache.has_shift = seq_id_src != seq_id_dst;
// NOTE: this is not correct for sequence swaps (which aren't a thing in the KV cache API yet)
cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
}
return;
}
cache.head = 0;
for (uint32_t i = 0; i < cache.size; ++i) {
@ -2309,6 +2383,10 @@ static void llama_kv_cache_seq_add(
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
if (cache.unlimited) {
GGML_ASSERT(false); // not supported
}
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.has_shift = true;
@ -2342,6 +2420,10 @@ static void llama_kv_cache_seq_div(
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
if (cache.unlimited) {
GGML_ASSERT(false); // not supported
}
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.has_shift = true;
@ -4943,6 +5025,8 @@ static void llm_build_kv_store(
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(kv.size == n_ctx);
// compute the transposed [n_tokens, n_embd] V matrix
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
@ -5152,6 +5236,8 @@ static struct ggml_tensor * llm_build_kqv(
cb(kq, "kq_soft_max_ext", il);
}
GGML_ASSERT(kv.size == n_ctx);
// split cached v into n_head heads
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
@ -5299,7 +5385,7 @@ struct llm_build_context {
norm_rms_eps (hparams.f_norm_rms_eps),
n_tokens (batch.n_tokens),
n_kv (worst_case ? kv_self.size : kv_self.n),
kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
kv_head (worst_case ? (kv_self.unlimited ? 0 : kv_self.size - n_tokens) : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx),
pooling_type (cparams.pooling_type),
rope_type (hparams.rope_type),
@ -5328,6 +5414,22 @@ struct llm_build_context {
struct ggml_cgraph * build_k_shift() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
// TODO: do this in a another graph with a dedicated input tensor
if (kv_self.unlimited) {
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_K_shift);
ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_K_shift);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
}
return gf;
}
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * tmp =
// we rotate only the first n_rot dimensions
@ -7905,8 +8007,6 @@ struct llm_build_context {
struct ggml_cgraph * build_mamba() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
const int32_t n_tok = batch.n_tokens;
const int64_t d_model = n_embd;
const int64_t d_inner = n_head;
GGML_ASSERT(2 * d_model == d_inner);
@ -7917,33 +8017,45 @@ struct llm_build_context {
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
// {n_embd, n_tok}
GGML_ASSERT(kv_self.used - kv_self.head + 1 == 1); // TODO: support more than one sequence per batch
// {n_embd, n_tokens}
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
cb(inpL, "inp_embd", -1);
for (int il = 0; il < n_layer; ++il) {
// (ab)using the kv cache to store the state
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv - 1, d_inner);
ggml_tensor * ssm_state = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);
ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], (d_conv-1)*(d_inner), kv_self.size);
ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], (d_state)*(d_inner), kv_self.size);
// reset the states when starting a new sequence
// TODO: ensure kv_self clearing is handled
if (!batch.pos || batch.pos[0] == 0) {
conv_state = ggml_scale(ctx0, conv_state, 0);
ssm_state = ggml_scale(ctx0, ssm_state, 0);
{
ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
// clear states of sequences which are starting at the beginning of this batch
conv_states = ggml_mul(ctx0,
ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
state_mask);
ssm_states = ggml_mul(ctx0,
ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]),
state_mask);
}
// TODO: support more than one sequence per batch (these could then use ggml_reshape_3d)
ggml_tensor * conv_state = ggml_view_2d(ctx0, conv_states, d_conv - 1, d_inner,
(d_conv - 1)*ggml_element_size(conv_states), 0);
ggml_tensor * ssm_state = ggml_view_2d(ctx0, ssm_states, d_state, d_inner,
(d_state)*ggml_element_size(ssm_states), 0);
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// {n_embd, 2*d_inner} * {n_embd, n_tok} => {2*d_inner, n_tok}
// {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens}
struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
// split the above in two
// => {d_inner, n_tok}
// => {d_inner, n_tokens}
struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
@ -7953,10 +8065,10 @@ struct llm_build_context {
// The following tensor is too big in order to avoid an assertion error when making an overlapping view.
// TODO: in ggml_new_tensor_impl, handle overlapping data range in data size calculation
// This could then be a tensor with ne[] = {(d_conv-1)+n_tok, d_inner},
// This could then be a tensor with ne[] = {(d_conv-1)+n_tokens, d_inner},
// but the size difference is not that big (d_conv is usually 4).
struct ggml_tensor * conv_x = ggml_new_tensor_1d(ctx0, conv_state->type, d_conv*d_inner*n_tok);
const size_t conv_x_nb1 = (d_conv - 1 + n_tok) * ggml_element_size(conv_x);
struct ggml_tensor * conv_x = ggml_new_tensor_1d(ctx0, conv_state->type, d_conv*d_inner*n_tokens);
const size_t conv_x_nb1 = (d_conv - 1 + n_tokens) * ggml_element_size(conv_x);
conv_x = ggml_set_2d(ctx0, conv_x, conv_state, conv_x_nb1, 0);
// making x contiguous is necessary because ggml_set expects it
@ -7965,18 +8077,18 @@ struct llm_build_context {
// store last (d_conv - 1) columns of conv_x back into the KV cache for the next conv_state
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tok*ggml_element_size(conv_x)),
ggml_view_tensor(ctx0, kv_self.k_l[il])));
ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tokens*ggml_element_size(conv_x)),
ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_x))));
// prepare convolution for all tokens in the batch with a self-overlapping view,
// shifting by one column each ... depth? ... with a window of d_conv columns.
// {(d_conv-1)+n_tok, d_inner} => {d_conv, d_inner, n_tok}
conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tok, conv_x_nb1, 1*ggml_element_size(conv_x), 0);
// {(d_conv-1)+n_tokens, d_inner} => {d_conv, d_inner, n_tokens}
conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tokens, conv_x_nb1, 1*ggml_element_size(conv_x), 0);
// perform convolution
// => {1, d_inner, n_tok}
// => {1, d_inner, n_tokens}
x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_x, model.layers[il].ssm_conv1d));
// => {d_inner, n_tok, 1}
// => {d_inner, n_tokens, 1}
x = ggml_permute(ctx0, x, 2, 0, 1, 3);
// bias
@ -7987,38 +8099,38 @@ struct llm_build_context {
// ssm
{
// {d_inner, dt_rank + 2*d_state} * {d_inner, n_tok} => {dt_rank + 2*d_state, n_tok}
// {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
// split
struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tok, x_db->nb[1], 0);
struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
// {dt_rank, d_inner} * {dt_rank, n_tok} => {d_inner, n_tok}
// {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
// Custom operator to implement some of the optimizations
// described in the Annex D of the Mamba paper.
// TODO: maybe also optimize step 4 of the Speed section of Annex D (the mul_mat with C)
// => {d_state, d_inner, n_tok}
// => {d_state, d_inner, n_tokens}
ssm_state = ggml_ssm_scan(ctx0, ssm_state, x, dt, model.layers[il].ssm_a, B);
// only store last state
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tok-1)*ssm_state->nb[2]),
ggml_view_tensor(ctx0, kv_self.v_l[il])));
ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tokens-1)*ssm_state->nb[2]),
ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner, kv_self.head*d_state*d_inner*ggml_element_size(ssm_state))));
// {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok}
// {d_state, d_inner, n_tokens} * {d_state, n_tokens} => {d_inner, 1, n_tokens}
struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
// => {d_inner, n_tok}
// => {d_inner, n_tokens}
y = ggml_permute(ctx0, y, 0, 2, 1, 3);
// {d_inner, n_tok} * {d_inner} => {d_inner, n_tok}
// {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
// {d_inner, n_embd} * {d_inner, n_tok} => {n_embd, n_tok}
// {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens}
cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
}
@ -8208,15 +8320,13 @@ static struct ggml_cgraph * llama_build_graph(
}
static void llama_set_k_shift(llama_context & lctx) {
const auto & cparams = lctx.cparams;
const int64_t n_ctx = cparams.n_ctx;
const int64_t kv_size = lctx.kv_self.size;
assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
int32_t * data = (int32_t *) lctx.inp_K_shift->data;
for (int i = 0; i < n_ctx; ++i) {
for (int i = 0; i < kv_size; ++i) {
data[i] = lctx.kv_self.cells[i].delta;
}
}
@ -8257,6 +8367,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
float * data = (float *) lctx.inp_KQ_mask->data;
// For Transformers, use only the previous KV cells
// of the correct sequence for each token of the batch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
@ -8274,6 +8387,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
}
// For Mamba (and other constant-time-and-size architectures),
// update the correct state(s)/sequence(s) for each token of the batch.
// Source and destination states are both the same for the sake of implementation simplicity.
// It would be more complex if they were sometimes the same and somtimes not.
// (with Transformers, source KV cells are never the destination,
// which is also simpler, but more memory hungry)
// TODO: implement
}
if (hparams.need_kq_pos) {
@ -8330,6 +8450,43 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
}
if (kv_self.unlimited) {
const uint32_t kv_size = kv_self.size;
const uint32_t n_kv = kv_self.n;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
float * data = (float *) lctx.inp_s_mask->data;
// states which are not affected by the current batch are left untouched
for (uint32_t i = 0; i < n_kv; ++i) {
llama_seq_id seq_id = i + lctx.kv_self.head;
llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
bool has_self_seq = kv_cell.has_seq_id(seq_id);
data[i] = (float) has_self_seq;
// ensure current sequences will be kept
if (!has_self_seq) {
kv_cell.seq_id.insert(seq_id);
}
}
// remove extraneous seq_ids when state copies are made
{
for (uint32_t i = 0; i < kv_size; ++i) {
llama_kv_cell & kv_cell = lctx.kv_self.cells[i];
uint32_t n_seqs = kv_cell.seq_id.size();
bool has_self_seq = kv_cell.has_seq_id(i);
if (has_self_seq && n_seqs > 1) {
kv_cell.seq_id.clear();
kv_cell.seq_id.insert(i);
} else if (!has_self_seq && n_seqs > 0) {
kv_cell.seq_id.clear();
}
}
}
}
}
static void llama_graph_compute(
@ -8450,6 +8607,7 @@ static int llama_decode_internal(
return 1;
}
if (!kv_self.unlimited) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
@ -8457,6 +8615,7 @@ static int llama_decode_internal(
//kv_self.n = llama_kv_cache_cell_max(kv_self);
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
}
ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
@ -8817,7 +8976,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
static void llama_kv_cache_update_internal(struct llama_context & lctx) {
// apply K-shift if needed
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
if ((lctx.kv_self.unlimited || lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) && lctx.kv_self.has_shift) {
llama_set_k_shift(lctx);
{
@ -8832,7 +8991,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
kv_self.has_shift = false;
for (uint32_t i = 0; i < kv_self.size; ++i) {
kv_self.cells[i].delta = 0;
kv_self.cells[i].delta = kv_self.unlimited ? i : 0;
}
}
}
@ -12122,6 +12281,7 @@ struct llama_context_params llama_context_default_params() {
/*.seed =*/ LLAMA_DEFAULT_SEED,
/*.n_ctx =*/ 512,
/*.n_batch =*/ 512,
/*.n_parallel =*/ 1,
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
@ -12283,6 +12443,7 @@ struct llama_context * llama_new_context_with_model(
auto & cparams = ctx->cparams;
cparams.n_batch = params.n_batch;
// TODO: maybe add n_parallel here too
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
@ -12339,14 +12500,19 @@ struct llama_context * llama_new_context_with_model(
ctx->rng = std::mt19937(params.seed);
ctx->logits_all = params.logits_all;
uint32_t kv_size = cparams.n_ctx;
ggml_type type_k = params.type_k;
ggml_type type_v = params.type_v;
// Mamba (mis)uses the KV cache to store its states
// Mamba only needs a constant number of KV cache slots per sequence
if (model->arch == LLM_ARCH_MAMBA) {
// Mamba needs as many slots as there are distinct sequences processed at the same time
// The extra slot allows dedicating a sequence id to the system prompt
// TODO: find a better way to get the max number of parallel sequences
kv_size = params.n_parallel + 1;
// it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_set for Mamba's conv_state
type_v = GGML_TYPE_F32; // required by ggml_mul for Mamba's ssm_state
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_state
}
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
@ -12447,7 +12613,7 @@ struct llama_context * llama_new_context_with_model(
}
ctx->backends.push_back(ctx->backend_cpu);
if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) {
if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;
@ -12481,7 +12647,7 @@ struct llama_context * llama_new_context_with_model(
// graph inputs
{
ggml_init_params init_params = {
/* .mem_size */ ggml_tensor_overhead()*8,
/* .mem_size */ ggml_tensor_overhead()*(8 + ctx->kv_self.unlimited),
/* .mem_buffer */ nullptr,
/* .no_alloc */ true,
};
@ -12490,11 +12656,13 @@ struct llama_context * llama_new_context_with_model(
ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch);
ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx);
ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, kv_size, cparams.n_batch);
ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
if (ctx->kv_self.unlimited)
ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
ggml_set_name(ctx->inp_tokens, "inp_tokens");
ggml_set_name(ctx->inp_embd, "inp_embd");
@ -12504,6 +12672,8 @@ struct llama_context * llama_new_context_with_model(
ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
ggml_set_name(ctx->inp_mean, "inp_mean");
ggml_set_name(ctx->inp_cls, "inp_cls");
if (ctx->kv_self.unlimited)
ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__,

View file

@ -235,6 +235,7 @@ extern "C" {
uint32_t seed; // RNG seed, -1 for random
uint32_t n_ctx; // text context, 0 = from model
uint32_t n_batch; // prompt processing maximum batch size
uint32_t n_parallel; // number of parallel sequences
uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing