mamba : multiple sequences, but one at a time
This is a step towards making this Mamba implementation usable with the
server example (the way the system prompt is kept when clearing the client
slots will need to be changed before this can work, though).

The KV cache size for this kind of model is tied to the maximum number of
sequences kept at any single time. For now, this number is obtained from
n_parallel (plus one, to have an extra sequence to dedicate to the system
prompt), but there might be a better way to do this which won't also make
the main example use 2 cells even if only 1 is really used
(for this specific case, --parallel 0 helps).

Simultaneous sequence processing will probably require changes to
ggml_ssm_scan, and possibly a new operator for the conv step.

* mamba : support llama_kv_cache_seq_cp

This (mis)uses the logic around K shifts, because tokens in a state can't
be shifted anyway, and because inp_K_shift has the right shape and type.
Using ggml_get_rows is a nice way to do copies, but copy chains can't work.
Fortunately, copy chains don't really seem to be used in the examples.

Each KV cell is dedicated to the sequence ID corresponding to its own index.

* mamba : use a state mask

It's cleaner than the previous heuristic of checking for the pos of the
first token in the batch. inp_KQ_mask could not be re-used for this, because
it has the wrong shape and because it seems more suited to the next step of
simultaneous sequence processing (helping with the problem of remembering
which token belongs to which sequence(s)/state(s)).

* llama : replace the usage of n_ctx with kv_self.size in many places

* mamba : use n_tokens directly instead of n_tok
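As a rough illustration of the bookkeeping described above, the following is a minimal, self-contained sketch (not the actual llama.cpp code; the struct and function names here are invented for the example) of how a recurrent-state cache can dedicate one cell per sequence, index cells directly by seq_id, and record a pending state copy in a per-cell delta field — which is what the diff below does with llama_kv_cache, inp_K_shift and ggml_get_rows:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // One cell holds the rolling state of one whole sequence.
    struct state_cell {
        int32_t pos   = -1;  // position of the last token absorbed into this state
        int32_t delta = 0;   // cell to copy the state from (its own index when no copy is pending)
    };

    struct state_cache {
        std::vector<state_cell> cells;   // kv_size cells, e.g. n_parallel + 1
        uint32_t head = 0, used = 0, n = 0;
        bool need_copy = false;

        // One token per entry, one seq_id per token (simplified from llama_batch).
        bool find_slot(const std::vector<uint32_t> & seq_id, const std::vector<int32_t> & pos) {
            head = cells.size() - 1;  // becomes the smallest seq_id of the batch
            used = 0;                 // becomes the biggest seq_id of the batch
            for (size_t i = 0; i < seq_id.size(); ++i) {
                const uint32_t s = seq_id[i];
                if (s >= cells.size())          { return false; } // seq_id out of range
                if (pos[i] != cells[s].pos + 1) { return false; } // states can't backtrack
                head = std::min(head, s);
                used = std::max(used, s);
                cells[s].pos = pos[i];
            }
            n = used - head + 1;  // contiguous range of cells touched by this batch
            return used >= head;
        }

        // "Copy" a sequence by remembering where its state should be gathered from.
        void seq_cp(uint32_t dst, uint32_t src) {
            cells[dst].delta = src;            // row index for a later gather (ggml_get_rows)
            cells[dst].pos   = cells[src].pos;
            need_copy       |= (dst != src);   // the diff below reuses has_shift for this
        }
    };

With this layout, llama_kv_cache_seq_cp only has to store the source index; the actual copy is performed later in the graph by gathering rows of the state tensors with the recorded indices, which is why copy chains cannot work.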
parent 6ff34da092
commit 8a43ffcfa1
4 changed files with 256 additions and 84 deletions
@@ -1295,6 +1295,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.n_ctx = params.n_ctx;
     cparams.n_batch = params.n_batch;
+    cparams.n_parallel = params.n_parallel;
     cparams.n_threads = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
     cparams.seed = params.seed;
ggml.c (22 changed lines)
@@ -6099,13 +6099,13 @@ struct ggml_tensor * ggml_ssm_scan(
     {
         const int64_t d_state = s->ne[0];
         const int64_t d_inner = s->ne[1];
-        const int64_t n_tok = x->ne[1];
+        const int64_t n_tokens = x->ne[1];

         GGML_ASSERT(x->ne[0] == d_inner);
         GGML_ASSERT(A->ne[0] == d_state);
         GGML_ASSERT(A->ne[1] == d_inner);
         GGML_ASSERT(B->ne[0] == d_state);
-        GGML_ASSERT(B->ne[1] == n_tok);
+        GGML_ASSERT(B->ne[1] == n_tokens);
     }

     bool is_node = false;

@@ -14682,12 +14682,12 @@ static void ggml_compute_forward_ssm_scan_f32(

     // first batch
     {
-        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tokens}
         float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
-        float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
-        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
+        float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
+        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens}
         float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float * B = (float *) ((char *) src4->data); // {d_state, n_tok}
+        float * B = (float *) ((char *) src4->data); // {d_state, n_tokens}
         // d_inner
         for (int i1 = 0; i1 < ir; ++i1) {
             float dt_soft_plus = log1pf(expf(dt[i1]));

@@ -14703,12 +14703,12 @@ static void ggml_compute_forward_ssm_scan_f32(

     // compute state for rest of tokens, previous state comes from dest
     for (int i2 = 1; i2 < n_t; ++i2) {
-        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
-        float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
-        float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok}
-        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tokens}
+        float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tokens}
+        float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tokens}
+        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tokens}
         float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tok}
+        float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
         // d_inner
         for (int i1 = 0; i1 < ir; ++i1) {
             float dt_soft_plus = log1pf(expf(dt[i1]));
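The renames above are purely cosmetic (n_tok to n_tokens), but for orientation, the recurrence that ggml_ssm_scan evaluates per token can be written out in scalar form. The following is a simplified reference sketch, not the actual multi-threaded ggml kernel; only the softplus line is taken from the code above, the rest reflects the Mamba selective-scan update, and the output projection with C happens outside this operator (as a ggml_mul_mat in the graph):

    #include <math.h>
    #include <vector>

    // Reference-only scalar version of the selective state-space scan.
    // s:  {d_state, d_inner}   previous state, updated in place
    // x:  {d_inner, n_tokens}  input per channel and token
    // dt: {d_inner, n_tokens}  per-channel step size (before softplus)
    // A:  {d_state, d_inner}   decay coefficients
    // B:  {d_state, n_tokens}  input projection per token
    static void ssm_scan_reference(std::vector<float> & s,
                                   const std::vector<float> & x,
                                   const std::vector<float> & dt,
                                   const std::vector<float> & A,
                                   const std::vector<float> & B,
                                   int d_state, int d_inner, int n_tokens) {
        for (int t = 0; t < n_tokens; ++t) {
            for (int i = 0; i < d_inner; ++i) {
                // same discretization as the kernel: softplus(dt)
                const float dt_soft_plus = log1pf(expf(dt[t*d_inner + i]));
                const float x_dt = x[t*d_inner + i] * dt_soft_plus;
                for (int j = 0; j < d_state; ++j) {
                    float & st = s[i*d_state + j];
                    // s <- s * exp(dt * A) + B * (dt * x)
                    st = st * expf(dt_soft_plus * A[i*d_state + j]) + B[t*d_state + j] * x_dt;
                }
            }
        }
    }

Indexing here assumes row-major packing of the ggml shapes listed in the comments; the real kernel also writes the state out for every token so that only the last one needs to be copied back into the cache.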
llama.cpp (298 changed lines)
@@ -1802,6 +1802,8 @@ struct llama_kv_cell {
 struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
+    // with Mamba, a slot can hold the state for more than one past token
+    bool unlimited = false;

     // Note: The value of head isn't only used to optimize searching
     // for a free KV slot. llama_decode_internal also uses it, so it

@@ -2036,11 +2038,12 @@ struct llama_context {
     struct ggml_tensor * inp_tokens; // I32 [n_batch]
     struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
     struct ggml_tensor * inp_pos; // I32 [n_batch]
-    struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
-    struct ggml_tensor * inp_KQ_pos; // F32 [n_ctx]
-    struct ggml_tensor * inp_K_shift; // I32 [n_ctx]
+    struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
+    struct ggml_tensor * inp_KQ_pos; // F32 [kv_size]
+    struct ggml_tensor * inp_K_shift; // I32 [kv_size]
     struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls; // I32 [n_batch]
+    struct ggml_tensor * inp_s_mask; // F32 [kv_size] (only used by constant state models like Mamba)

 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;

@@ -2056,7 +2059,7 @@ static bool llama_kv_cache_init(
         const llama_model & model,
         ggml_type type_k,
         ggml_type type_v,
-        uint32_t n_ctx,
+        uint32_t kv_size,
         bool offload) {
     const struct llama_hparams & hparams = model.hparams;

@@ -2064,22 +2067,26 @@ static bool llama_kv_cache_init(
     const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
     const int64_t n_layer = hparams.n_layer;

-    if (model.arch == LLM_ARCH_MAMBA) {
-        // only one slot is needed for Mamba
-        n_ctx = 1;
-    }
-
     cache.has_shift = false;

+    // for now, only Mamba can hold state for more than one past token per slot
+    cache.unlimited = model.arch == LLM_ARCH_MAMBA;
+
     cache.head = 0;
-    cache.size = n_ctx;
+    cache.size = kv_size;
     cache.used = 0;

     cache.type_k = type_k;
     cache.type_v = type_v;

     cache.cells.clear();
-    cache.cells.resize(n_ctx);
+    cache.cells.resize(kv_size);
+
+    if (cache.unlimited) {
+        for (uint32_t i = 0; i < cache.size; ++i) {
+            cache.cells[i].delta = i;
+        }
+    } // else, delta is already initialized to zero

 #ifdef GGML_USE_CLBLAST
     offload = false;

@@ -2118,8 +2125,8 @@ static bool llama_kv_cache_init(

     for (int i = 0; i < (int) n_layer; i++) {
         struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*n_ctx);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*n_ctx);
+        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         cache.k_l.push_back(k);
@@ -2153,11 +2160,51 @@ static bool llama_kv_cache_find_slot(
     const uint32_t n_ctx = cache.size;
     const uint32_t n_tokens = batch.n_tokens;

-    // for Mamba and/or other model archs that only ever use one slot
-    if (n_ctx == 1) {
-        // hopefully no one actually uses a context size of 1 on Transformer-based models
-        return true;
+    if (cache.unlimited) {
+        // For unlimited context architectures (like Mamba),
+        // each KV cache cell can store the state for a whole sequence.
+
+        // starting point to find the minimum seq_id used in the batch
+        cache.head = cache.size - 1;
+        // likewise, to find the max seq_id in the batch
+        cache.used = 0;
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
+                llama_seq_id seq_id = batch.seq_id[i][j];
+                // make sure it's a valid seq_id
+                if ((uint32_t)seq_id < cache.size) {
+                    // the number of "used" cells is simply the biggest seq_id
+                    if (cache.used < (uint32_t)seq_id) {
+                        cache.used = seq_id;
+                    }
+                    // the "head" is the smallest seq_id
+                    if (cache.head > (uint32_t)seq_id) {
+                        cache.head = seq_id;
+                    }
+                    // Assuming the tokens are in-order
+                    if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
+                        // What should happen when the pos backtracks?
+                        // Clearing the state mid-batch would require special-casing which isn't done.
+                        LLAMA_LOG_ERROR("%s: non-consecutive token position %d after %d for sequence %d\n",
+                            __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
+                        return false;
+                    }
+                    cache.cells[seq_id].pos = batch.pos[i];
+                    // NOTE: seq_ids are not inserted here, because they are handled when the graph is built.
+                } else {
+                    // too big seq_id
+                    // TODO: would it be possible to resize the KV cache size instead?
+                    LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d\n", __func__, seq_id, cache.size);
+                    return false;
+                }
+            }
+        }
+
+        cache.n = cache.used - cache.head + 1;
+        // sanity check (max >= min)
+        return cache.used >= cache.head;
     }
+    // otherwise, one cell per token.

     if (n_tokens > n_ctx) {
         LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
@@ -2238,6 +2285,13 @@ static void llama_kv_cache_seq_rm(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

+    if (cache.unlimited) {
+        // can only remove whole sequences for models like Mamba
+        GGML_ASSERT(p0 == 0);
+        GGML_ASSERT((uint32_t)seq_id < cache.size);
+        GGML_ASSERT(cache.cells[seq_id].pos < p1);
+    }
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             if (seq_id < 0) {

@@ -2270,6 +2324,26 @@ static void llama_kv_cache_seq_cp(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

+    if (cache.unlimited) {
+        if ((uint32_t)seq_id_dst < cache.size && (uint32_t)seq_id_src < cache.size) {
+            // intent to "copy from" (does not support copy chains)
+            cache.cells[seq_id_dst].delta = seq_id_src;
+            // NOTE: a sequence can't have multiple sources, but can have multiple destinations.
+            // For compatibility with the other KV cache API functions,
+            // the seq_id(s) of a slot suggests an intent to "copy to" those id(s),
+            // so that when a sequence is copied, it can initially be found from the source cell.
+            cache.cells[seq_id_src].seq_id.insert(seq_id_dst);
+            // prevent the destination from getting cleared
+            cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+            // repurposed as a "need copy" flag
+            // (shifting can't be done anyway for this kind of KV cache)
+            cache.has_shift = seq_id_src != seq_id_dst;
+            // NOTE: this is not correct for sequence swaps (which aren't a thing in the KV cache API yet)
+            cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
+        }
+        return;
+    }
+
     cache.head = 0;

     for (uint32_t i = 0; i < cache.size; ++i) {

@@ -2309,6 +2383,10 @@ static void llama_kv_cache_seq_add(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

+    if (cache.unlimited) {
+        GGML_ASSERT(false); // not supported
+    }
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.has_shift = true;

@@ -2342,6 +2420,10 @@ static void llama_kv_cache_seq_div(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

+    if (cache.unlimited) {
+        GGML_ASSERT(false); // not supported
+    }
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.has_shift = true;
@@ -4943,6 +5025,8 @@ static void llm_build_kv_store(
     const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
     const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();

+    GGML_ASSERT(kv.size == n_ctx);
+
     // compute the transposed [n_tokens, n_embd] V matrix
     struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
     //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed

@@ -5152,6 +5236,8 @@ static struct ggml_tensor * llm_build_kqv(
         cb(kq, "kq_soft_max_ext", il);
     }

+    GGML_ASSERT(kv.size == n_ctx);
+
     // split cached v into n_head heads
     struct ggml_tensor * v =
         ggml_view_3d(ctx, kv.v_l[il],

@@ -5299,7 +5385,7 @@ struct llm_build_context {
         norm_rms_eps (hparams.f_norm_rms_eps),
         n_tokens (batch.n_tokens),
         n_kv (worst_case ? kv_self.size : kv_self.n),
-        kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
+        kv_head (worst_case ? (kv_self.unlimited ? 0 : kv_self.size - n_tokens) : kv_self.head),
         n_orig_ctx (cparams.n_yarn_orig_ctx),
         pooling_type (cparams.pooling_type),
         rope_type (hparams.rope_type),

@@ -5328,6 +5414,22 @@ struct llm_build_context {
     struct ggml_cgraph * build_k_shift() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

+        // TODO: do this in a another graph with a dedicated input tensor
+        if (kv_self.unlimited) {
+            for (int il = 0; il < n_layer; ++il) {
+                ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
+                ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
+
+                conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_K_shift);
+                ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_K_shift);
+
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
+            }
+
+            return gf;
+        }
+
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * tmp =
                 // we rotate only the first n_rot dimensions
@@ -7905,8 +8007,6 @@ struct llm_build_context {
     struct ggml_cgraph * build_mamba() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

-        const int32_t n_tok = batch.n_tokens;
-
         const int64_t d_model = n_embd;
         const int64_t d_inner = n_head;
         GGML_ASSERT(2 * d_model == d_inner);

@@ -7917,33 +8017,45 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;

-        // {n_embd, n_tok}
+        GGML_ASSERT(kv_self.used - kv_self.head + 1 == 1); // TODO: support more than one sequence per batch
+
+        // {n_embd, n_tokens}
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);

         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the kv cache to store the state
-            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-            ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv - 1, d_inner);
-            ggml_tensor * ssm_state = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);
+            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], (d_conv-1)*(d_inner), kv_self.size);
+            ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], (d_state)*(d_inner), kv_self.size);

-            // reset the states when starting a new sequence
-            // TODO: ensure kv_self clearing is handled
-            if (!batch.pos || batch.pos[0] == 0) {
-                conv_state = ggml_scale(ctx0, conv_state, 0);
-                ssm_state = ggml_scale(ctx0, ssm_state, 0);
+            {
+                ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
+                // clear states of sequences which are starting at the beginning of this batch
+                conv_states = ggml_mul(ctx0,
+                    ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
+                    state_mask);
+                ssm_states = ggml_mul(ctx0,
+                    ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]),
+                    state_mask);
             }

+            // TODO: support more than one sequence per batch (these could then use ggml_reshape_3d)
+            ggml_tensor * conv_state = ggml_view_2d(ctx0, conv_states, d_conv - 1, d_inner,
+                (d_conv - 1)*ggml_element_size(conv_states), 0);
+            ggml_tensor * ssm_state = ggml_view_2d(ctx0, ssm_states, d_state, d_inner,
+                (d_state)*ggml_element_size(ssm_states), 0);
+
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                 model.layers[il].attn_norm, NULL,
                 LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);

-            // {n_embd, 2*d_inner} * {n_embd, n_tok} => {2*d_inner, n_tok}
+            // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens}
             struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
             // split the above in two
-            // => {d_inner, n_tok}
+            // => {d_inner, n_tokens}
             struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
             struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);

@@ -7953,10 +8065,10 @@ struct llm_build_context {

             // The following tensor is too big in order to avoid an assertion error when making an overlapping view.
             // TODO: in ggml_new_tensor_impl, handle overlapping data range in data size calculation
-            // This could then be a tensor with ne[] = {(d_conv-1)+n_tok, d_inner},
+            // This could then be a tensor with ne[] = {(d_conv-1)+n_tokens, d_inner},
             // but the size difference is not that big (d_conv is usually 4).
-            struct ggml_tensor * conv_x = ggml_new_tensor_1d(ctx0, conv_state->type, d_conv*d_inner*n_tok);
-            const size_t conv_x_nb1 = (d_conv - 1 + n_tok) * ggml_element_size(conv_x);
+            struct ggml_tensor * conv_x = ggml_new_tensor_1d(ctx0, conv_state->type, d_conv*d_inner*n_tokens);
+            const size_t conv_x_nb1 = (d_conv - 1 + n_tokens) * ggml_element_size(conv_x);

             conv_x = ggml_set_2d(ctx0, conv_x, conv_state, conv_x_nb1, 0);
             // making x contiguous is necessary because ggml_set expects it

@@ -7965,18 +8077,18 @@ struct llm_build_context {
             // store last (d_conv - 1) columns of conv_x back into the KV cache for the next conv_state
             ggml_build_forward_expand(gf,
                 ggml_cpy(ctx0,
-                    ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tok*ggml_element_size(conv_x)),
-                    ggml_view_tensor(ctx0, kv_self.k_l[il])));
+                    ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tokens*ggml_element_size(conv_x)),
+                    ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_x))));

             // prepare convolution for all tokens in the batch with a self-overlapping view,
             // shifting by one column each ... depth? ... with a window of d_conv columns.
-            // {(d_conv-1)+n_tok, d_inner} => {d_conv, d_inner, n_tok}
-            conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tok, conv_x_nb1, 1*ggml_element_size(conv_x), 0);
+            // {(d_conv-1)+n_tokens, d_inner} => {d_conv, d_inner, n_tokens}
+            conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tokens, conv_x_nb1, 1*ggml_element_size(conv_x), 0);

             // perform convolution
-            // => {1, d_inner, n_tok}
+            // => {1, d_inner, n_tokens}
             x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_x, model.layers[il].ssm_conv1d));
-            // => {d_inner, n_tok, 1}
+            // => {d_inner, n_tokens, 1}
             x = ggml_permute(ctx0, x, 2, 0, 1, 3);

             // bias

@@ -7987,38 +8099,38 @@ struct llm_build_context {

             // ssm
             {
-                // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tok} => {dt_rank + 2*d_state, n_tok}
+                // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
                 struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
                 // split
-                struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tok, x_db->nb[1], 0);
-                struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
-                struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
+                struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
+                struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+                struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));

-                // {dt_rank, d_inner} * {dt_rank, n_tok} => {d_inner, n_tok}
+                // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
                 dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
                 dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);

                 // Custom operator to implement some of the optimizations
                 // described in the Annex D of the Mamba paper.
                 // TODO: maybe also optimize step 4 of the Speed section of Annex D (the mul_mat with C)
-                // => {d_state, d_inner, n_tok}
+                // => {d_state, d_inner, n_tokens}
                 ssm_state = ggml_ssm_scan(ctx0, ssm_state, x, dt, model.layers[il].ssm_a, B);

                 // only store last state
                 ggml_build_forward_expand(gf,
                     ggml_cpy(ctx0,
-                        ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tok-1)*ssm_state->nb[2]),
-                        ggml_view_tensor(ctx0, kv_self.v_l[il])));
+                        ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tokens-1)*ssm_state->nb[2]),
+                        ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner, kv_self.head*d_state*d_inner*ggml_element_size(ssm_state))));

-                // {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok}
+                // {d_state, d_inner, n_tokens} * {d_state, n_tokens} => {d_inner, 1, n_tokens}
                 struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
-                // => {d_inner, n_tok}
+                // => {d_inner, n_tokens}
                 y = ggml_permute(ctx0, y, 0, 2, 1, 3);
-                // {d_inner, n_tok} * {d_inner} => {d_inner, n_tok}
+                // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
                 y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
                 y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));

-                // {d_inner, n_embd} * {d_inner, n_tok} => {n_embd, n_tok}
+                // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens}
                 cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
             }
@@ -8208,15 +8320,13 @@ static struct ggml_cgraph * llama_build_graph(
 }

 static void llama_set_k_shift(llama_context & lctx) {
-    const auto & cparams = lctx.cparams;
-
-    const int64_t n_ctx = cparams.n_ctx;
+    const int64_t kv_size = lctx.kv_self.size;

     assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));

     int32_t * data = (int32_t *) lctx.inp_K_shift->data;

-    for (int i = 0; i < n_ctx; ++i) {
+    for (int i = 0; i < kv_size; ++i) {
         data[i] = lctx.kv_self.cells[i].delta;
     }
 }

@@ -8257,6 +8367,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

         float * data = (float *) lctx.inp_KQ_mask->data;

+        // For Transformers, use only the previous KV cells
+        // of the correct sequence for each token of the batch.
+        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
                 const llama_pos pos = batch.pos[j];

@@ -8274,6 +8387,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 }
             }
         }
+        // For Mamba (and other constant-time-and-size architectures),
+        // update the correct state(s)/sequence(s) for each token of the batch.
+        // Source and destination states are both the same for the sake of implementation simplicity.
+        // It would be more complex if they were sometimes the same and somtimes not.
+        // (with Transformers, source KV cells are never the destination,
+        // which is also simpler, but more memory hungry)
+        // TODO: implement
     }

     if (hparams.need_kq_pos) {

@@ -8330,6 +8450,43 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             }
         }
     }

+    if (kv_self.unlimited) {
+        const uint32_t kv_size = kv_self.size;
+        const uint32_t n_kv = kv_self.n;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
+        float * data = (float *) lctx.inp_s_mask->data;
+
+        // states which are not affected by the current batch are left untouched
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            llama_seq_id seq_id = i + lctx.kv_self.head;
+            llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
+            bool has_self_seq = kv_cell.has_seq_id(seq_id);
+
+            data[i] = (float) has_self_seq;
+
+            // ensure current sequences will be kept
+            if (!has_self_seq) {
+                kv_cell.seq_id.insert(seq_id);
+            }
+        }
+        // remove extraneous seq_ids when state copies are made
+        {
+            for (uint32_t i = 0; i < kv_size; ++i) {
+                llama_kv_cell & kv_cell = lctx.kv_self.cells[i];
+                uint32_t n_seqs = kv_cell.seq_id.size();
+                bool has_self_seq = kv_cell.has_seq_id(i);
+
+                if (has_self_seq && n_seqs > 1) {
+                    kv_cell.seq_id.clear();
+                    kv_cell.seq_id.insert(i);
+                } else if (!has_self_seq && n_seqs > 0) {
+                    kv_cell.seq_id.clear();
+                }
+            }
+        }
+    }
 }

 static void llama_graph_compute(
@@ -8450,6 +8607,7 @@ static int llama_decode_internal(
         return 1;
     }

+    if (!kv_self.unlimited) {
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
    // if we start defragmenting the cache, the benefit from this will be more important

@@ -8457,6 +8615,7 @@ static int llama_decode_internal(
     //kv_self.n = llama_kv_cache_cell_max(kv_self);

     //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
+    }

     ggml_backend_sched_reset(lctx.sched);
     ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);

@@ -8817,7 +8976,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {

 static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     // apply K-shift if needed
-    if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
+    if ((lctx.kv_self.unlimited || lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) && lctx.kv_self.has_shift) {
         llama_set_k_shift(lctx);

         {

@@ -8832,7 +8991,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
             kv_self.has_shift = false;

             for (uint32_t i = 0; i < kv_self.size; ++i) {
-                kv_self.cells[i].delta = 0;
+                kv_self.cells[i].delta = kv_self.unlimited ? i : 0;
             }
         }
     }
@@ -12122,6 +12281,7 @@ struct llama_context_params llama_context_default_params() {
         /*.seed =*/ LLAMA_DEFAULT_SEED,
         /*.n_ctx =*/ 512,
         /*.n_batch =*/ 512,
+        /*.n_parallel =*/ 1,
         /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
         /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,

@@ -12283,6 +12443,7 @@ struct llama_context * llama_new_context_with_model(
     auto & cparams = ctx->cparams;

     cparams.n_batch = params.n_batch;
+    // TODO: maybe add n_parallel here too
     cparams.n_threads = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
     cparams.yarn_ext_factor = params.yarn_ext_factor;

@@ -12339,14 +12500,19 @@ struct llama_context * llama_new_context_with_model(
     ctx->rng = std::mt19937(params.seed);
     ctx->logits_all = params.logits_all;

+    uint32_t kv_size = cparams.n_ctx;
     ggml_type type_k = params.type_k;
     ggml_type type_v = params.type_v;

-    // Mamba (mis)uses the KV cache to store its states
+    // Mamba only needs a constant number of KV cache slots per sequence
     if (model->arch == LLM_ARCH_MAMBA) {
+        // Mamba needs as many slots as there are distinct sequences processed at the same time
+        // The extra slot allows dedicating a sequence id to the system prompt
+        // TODO: find a better way to get the max number of parallel sequences
+        kv_size = params.n_parallel + 1;
         // it's probably best to keep as much precision as possible for the states
         type_k = GGML_TYPE_F32; // required by ggml_set for Mamba's conv_state
-        type_v = GGML_TYPE_F32; // required by ggml_mul for Mamba's ssm_state
+        type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_state
     }

     GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);

@@ -12447,7 +12613,7 @@ struct llama_context * llama_new_context_with_model(
     }
     ctx->backends.push_back(ctx->backend_cpu);

-    if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) {
+    if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) {
         LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
         llama_free(ctx);
         return nullptr;

@@ -12481,7 +12647,7 @@ struct llama_context * llama_new_context_with_model(
     // graph inputs
     {
         ggml_init_params init_params = {
-            /* .mem_size */ ggml_tensor_overhead()*8,
+            /* .mem_size */ ggml_tensor_overhead()*(8 + ctx->kv_self.unlimited),
             /* .mem_buffer */ nullptr,
             /* .no_alloc */ true,
         };

@@ -12490,11 +12656,13 @@ struct llama_context * llama_new_context_with_model(
         ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
         ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch);
         ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
-        ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
-        ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx);
-        ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
+        ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, kv_size, cparams.n_batch);
+        ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
+        ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
         ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
         ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
+        if (ctx->kv_self.unlimited)
+            ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);

         ggml_set_name(ctx->inp_tokens, "inp_tokens");
         ggml_set_name(ctx->inp_embd, "inp_embd");

@@ -12504,6 +12672,8 @@ struct llama_context * llama_new_context_with_model(
         ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
         ggml_set_name(ctx->inp_mean, "inp_mean");
         ggml_set_name(ctx->inp_cls, "inp_cls");
+        if (ctx->kv_self.unlimited)
+            ggml_set_name(ctx->inp_s_mask, "inp_s_mask");

         ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
         LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__,
llama.h (1 changed line)
@@ -235,6 +235,7 @@ extern "C" {
         uint32_t seed; // RNG seed, -1 for random
         uint32_t n_ctx; // text context, 0 = from model
         uint32_t n_batch; // prompt processing maximum batch size
+        uint32_t n_parallel; // number of parallel sequences
        uint32_t n_threads; // number of threads to use for generation
        uint32_t n_threads_batch; // number of threads to use for batch processing
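To show how the new field is meant to be used, here is a hypothetical caller-side sketch (not part of this commit; the helper name and the chosen values are invented for illustration). For a Mamba model, n_parallel bounds how many recurrent states the cache allocates (n_parallel + 1 cells, the extra one being reserved for a system prompt), while n_ctx no longer determines the cache size:

    #include "llama.h"

    // Hypothetical helper: one recurrent state per client sequence.
    static llama_context * make_mamba_context(llama_model * model, uint32_t n_clients) {
        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx      = 4096;       // context length, unrelated to the Mamba state cache size
        cparams.n_parallel = n_clients;  // the cache gets n_parallel + 1 state cells
        return llama_new_context_with_model(model, cparams);
    }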