mamba : begin figuring out how to (ab)use the kv cache for Mamba
parent 8cd0a286b4 · commit 5a69a262a1
2 changed files with 80 additions and 35 deletions
@@ -1849,9 +1849,13 @@ class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
     def set_gguf_parameters(self):
+        d_model = self.hparams["d_model"]
         self.gguf_writer.add_name(self.dir_model.name)
-        self.gguf_writer.add_embedding_length(self.hparams["d_model"])
+        self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_head_count(2 * d_model) # d_inner
+        self.gguf_writer.add_key_length(4) # d_conv
+        self.gguf_writer.add_value_length(16) # d_state
         self.gguf_writer.add_file_type(self.ftype)
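Note that the hunk above does not add new GGUF keys; it reuses existing ones, so head_count carries d_inner (2 * d_model), key_length carries d_conv (4), and value_length carries d_state (16). A minimal sketch of how the loader side can recover the Mamba dimensions from those repurposed fields (the hparams names are the ones this diff itself uses later; the struct and helper are illustrative only, not part of the patch):

    // Illustrative helper: map the repurposed GGUF keys back to Mamba dimensions.
    struct mamba_dims {
        int64_t d_model;   // embedding_length
        int64_t d_inner;   // head_count   (repurposed) = 2 * d_model
        int64_t d_conv;    // key_length   (repurposed) = 4
        int64_t d_state;   // value_length (repurposed) = 16
        int64_t dt_rank;   // derived: ceil(d_model / 16)
    };

    static mamba_dims mamba_dims_from_hparams(const llama_hparams & hparams) {
        mamba_dims dims;
        dims.d_model = hparams.n_embd;
        dims.d_inner = hparams.n_head;          // repurposed
        dims.d_conv  = hparams.n_embd_head_k;   // repurposed
        dims.d_state = hparams.n_embd_head_v;   // repurposed
        dims.dt_rank = (dims.d_model + 15) / 16;
        return dims;
    }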
llama.cpp (95 changes)
@@ -2067,6 +2067,14 @@ static bool llama_kv_cache_init(
     const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
     const int64_t  n_layer      = hparams.n_layer;
 
+    if (model.arch == LLM_ARCH_MAMBA) {
+        // only one slot is needed for Mamba
+        n_ctx = 1;
+        // it's probably best to keep as much precision as possible for the states
+        ktype = GGML_TYPE_F32;
+        vtype = GGML_TYPE_F32;
+    }
+
     cache.has_shift = false;
 
     cache.head = 0;
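With one F32 cell per layer and the shapes used elsewhere in this diff (d_conv = 4, d_state = 16, d_inner = 2 * d_model), the recurrent state is small, so keeping it in full precision costs little. A back-of-the-envelope sketch (assumed sizes, not code from the patch):

    // Per-layer state parked in the repurposed KV cache (single cell, F32):
    //   k_l[il] holds the conv state {d_conv, d_inner}
    //   v_l[il] holds the ssm  state {d_state, d_inner}
    static size_t mamba_state_bytes_per_layer(int64_t d_model) {
        const int64_t d_inner = 2 * d_model;
        const int64_t d_conv  = 4;
        const int64_t d_state = 16;
        return (size_t)(d_conv + d_state) * d_inner * sizeof(float);
    }
    // e.g. d_model = 768 (mamba-130m): (4 + 16) * 1536 * 4 bytes = 120 KiB per layer,
    // independent of context length.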
@@ -2151,6 +2159,12 @@ static bool llama_kv_cache_find_slot(
     const uint32_t n_ctx    = cache.size;
     const uint32_t n_tokens = batch.n_tokens;
 
+    // for Mamba and/or other model archs that only ever use one slot
+    if (n_ctx == 1) {
+        // hopefully no one actually uses a context size of 1 on Transformer-based models
+        return true;
+    }
+
     if (n_tokens > n_ctx) {
         LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
         return false;
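The early return works because a recurrent cache is overwritten in place: with n_ctx == 1 every batch lands in cell 0, so there is nothing to search for or evict. A tiny illustration of the contrast (hypothetical, simplified):

    // Transformer-style ring of cells vs. a single-slot recurrent cache.
    static uint32_t next_cell(uint32_t head, uint32_t n_ctx) {
        return (n_ctx == 1) ? 0 : (head + 1) % n_ctx; // single-slot caches never advance
    }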
@@ -4666,12 +4680,18 @@ static bool llm_load_tensors(
             {
                 model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
 
+                const int64_t d_conv  = hparams.n_embd_head_k;
+                const int64_t d_state = hparams.n_embd_head_v;
+                const int64_t d_inner = hparams.n_head;
+                // FIXME: ceiling instead of floor
+                const int64_t dt_rank = n_embd / 16;
+                GGML_ASSERT(2 * n_embd == d_inner);
+
                 // output
                 {
                     model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
                     model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
                 }
-                // TODO: MAMBA
 
                 for (int i = 0; i < n_layer; ++i) {
                     ggml_context * ctx_layer = ctx_for_layer(i);
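About the FIXME on ceiling vs. floor: the reference Mamba implementation derives dt_rank as ceil(d_model / 16), so the floor division above only agrees when d_model is a multiple of 16. A sketch of the integer-ceiling form, assuming that is the intended semantics:

    // dt_rank = ceil(d_model / 16) with integer arithmetic:
    // d_model = 2560 -> 160 either way, but d_model = 2568 -> 161 (floor would give 160).
    const int64_t dt_rank = (n_embd + 16 - 1) / 16;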
@@ -4679,19 +4699,30 @@ static bool llm_load_tensors(
 
                     auto & layer = model.layers[i];
 
-                    // norm
-                    layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                    // TODO: D, in_proj, conv1d, x_proj, dt_proj, A_log, out_proj
+                    // TODO: what's the difference between ctx_layer and ctx_split?
+                    // A: It seems that ctx_split is for matrices (2d???) while ctx_layer is for other things (like 1D bias and norms, probably.)
+
+                    // norm
+                    layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                    layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});
+
+                    layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, 1, d_inner});
+                    layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner});
+
+                    layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});
+
+                    layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner});
+                    layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner});
+
+                    // FIXME: maybe no suffix for these
+                    layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, "weight", i), {d_state, d_inner});
+                    layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, "weight", i), {d_inner});
 
                     // out_proj
-                    layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {2*n_embd, n_embd});
-
-
-                }
+                    layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
                 }
             } break;
         default:
             throw std::runtime_error("unknown architecture");
     }
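Taken together, the tensors above follow the reference Mamba block layout. As a worked example of the resulting shapes (a sketch assuming mamba-130m sizes: d_model = 768, so d_inner = 1536, d_conv = 4, d_state = 16, dt_rank = 48):

    // Per-layer weight shapes and element counts for d_model = 768:
    //   ssm_in       {n_embd, 2*d_inner}            768 * 3072   = 2,359,296
    //   ssm_conv1d   {d_conv, 1, d_inner}           4 * 1 * 1536 =     6,144
    //   ssm_conv1d_b {d_inner}                                   =     1,536
    //   ssm_x        {d_inner, dt_rank + 2*d_state} 1536 * 80    =   122,880
    //   ssm_dt       {dt_rank, d_inner}             48 * 1536    =    73,728
    //   ssm_dt_b     {d_inner}                                   =     1,536
    //   ssm_a        {d_state, d_inner}             16 * 1536    =    24,576
    //   ssm_d        {d_inner}                                   =     1,536
    //   ssm_out      {d_inner, n_embd}              1536 * 768   = 1,179,648
    // Roughly 3.77M parameters per layer; 24 such layers plus the token embedding
    // put the total in the ballpark of the advertised 130M.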
@@ -5272,7 +5303,7 @@ struct llm_build_context {
         norm_eps         (hparams.f_norm_eps),
         norm_rms_eps     (hparams.f_norm_rms_eps),
         n_tokens         (batch.n_tokens),
-        n_kv             (worst_case ? n_ctx : kv_self.n),
+        n_kv             (worst_case ? kv_self.size : kv_self.n),
         kv_head          (worst_case ? n_ctx - n_tokens : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
         pooling_type     (cparams.pooling_type),
@@ -7876,28 +7907,30 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_mamba() {
+    struct ggml_cgraph * build_mamba(bool use_conv) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
-        // d_model
-        const int64_t n_embd = hparams.n_embd;
-        const int64_t d_state = 16;
-        const int64_t d_conv = 4;
-        // expand = 2
-        // d_inner = expand * d_model
-        const int64_t d_inner = 2 * n_embd; // FIXME: this is wrong
+        GGML_ASSERT(use_conv == false); // TODO: implement
+
+        const int64_t d_model = hparams.n_embd;
+        const int64_t d_inner = hparams.n_head;
+        GGML_ASSERT(2 * d_model == d_inner);
+        const int64_t d_conv  = hparams.n_embd_head_k;
+        const int64_t d_state = hparams.n_embd_head_v;
+        const int64_t dt_rank = d_model / 16;
 
         struct ggml_tensor * cur;
        struct ggml_tensor * inpL;
 
-        // TODO: give it the right size
-        struct ggml_tensor * state;
-
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        // {n_embd, batch}
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         for (int il = 0; il < n_layer; ++il) {
-            // FIXME: init attn_norm
+            // (ab)using the kv cache to store the state
+            ggml_tensor * conv_state = kv_self.k_l[il]; // {d_conv, d_inner}
+            ggml_tensor * ssm_state  = kv_self.v_l[il]; // {d_state, d_inner}
+
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
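For reference, the two tensors borrowed from the cache correspond to the recurrence the rest of build_mamba still has to implement: the conv state keeps the last d_conv inputs per channel, and the ssm state is the selective-scan hidden state. A scalar sketch of the selective-scan step for one channel (reference semantics only, not the ggml graph):

    #include <cmath>

    // h has d_state entries and persists across tokens (what v_l[il] is meant to hold).
    // A, B, C are the per-channel SSM parameters, dt the per-channel step size, D the skip weight.
    static float ssm_step(float x, float dt, const float * A, const float * B,
                          const float * C, float D, float * h, int d_state) {
        float y = 0.0f;
        for (int j = 0; j < d_state; ++j) {
            h[j] = std::exp(dt * A[j]) * h[j] + dt * B[j] * x; // discretized state update
            y   += C[j] * h[j];                                // readout
        }
        return y + D * x;                                      // skip connection
    }

The gated output is then y * silu(z), which is why xz is split into x and z in the next hunk.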
@@ -7905,14 +7938,18 @@ struct llm_build_context {
             // TODO: that's probably the wrong name.
             cb(cur, "attn_norm", il);
 
-            // conv
-            {
-                // [] * [] = [2*n_embd]
+            // {n_embd, batch} * {n_embd, 2*d_inner} = {batch, 2*d_inner}
                 struct ggml_tensor * xz = ggml_mul_mat(ctx0, cur, model.layers[il].ssm_in);
                 // split the above in two
                 struct ggml_tensor * x = ggml_view_1d(ctx0, xz, d_inner, 0);
                 struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, d_inner);
 
+            // FIXME: figure out when to transpose
+            // conv
+            {
+                // TODO: figure out how to do a row-wise dot product
                 // TODO: use the kv-cache to store the state
                 kv_self.k_l[il];
+
+                // FIXME: this is wrong
                 cur = ggml_conv_1d(ctx0, cur, model.layers[il].ssm_conv1d, 1, d_conv - 1, 1);
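The FIXME is warranted: the reference model applies a depthwise causal convolution per channel, so during incremental decoding only the last d_conv inputs of each channel are needed, which is exactly what the {d_conv, d_inner} conv state above can hold. A scalar sketch of that rolling update for one channel (reference semantics, an assumed helper, not what ggml_conv_1d does here):

    // One decoding step for a single channel: shift the window, append x, then dot with the kernel.
    // conv_state has d_conv entries for this channel (a slice of k_l[il]), w is the d_conv-wide
    // kernel from ssm_conv1d for this channel, b its bias; SiLU is applied afterwards.
    static float causal_conv_step(float x, const float * w, float b,
                                  float * conv_state, int d_conv) {
        for (int t = 0; t < d_conv - 1; ++t) {
            conv_state[t] = conv_state[t + 1]; // shift the window left
        }
        conv_state[d_conv - 1] = x;            // newest input last
        float y = b;
        for (int t = 0; t < d_conv; ++t) {
            y += w[t] * conv_state[t];         // depthwise dot product
        }
        return y;
    }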
@@ -8111,6 +8148,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_starcoder2();
             } break;
+        case LLM_ARCH_MAMBA:
+            {
+                result = llm.build_mamba(/* use_conv =*/ batch.n_tokens > 1);
+            } break;
         default:
             GGML_ASSERT(false);
     }
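Since build_mamba still asserts use_conv == false, only single-token batches can go through for now; prompt processing (batch.n_tokens > 1) will need the convolution applied across the whole batch. A hypothetical stop-gap until then is to feed the prompt one token at a time so every call takes the recurrent path (names below are illustrative, not llama.cpp API):

    // Illustrative driver loop: process a prompt with batch size 1 so each call
    // only has to update conv_state and ssm_state for a single token.
    for (size_t i = 0; i < prompt.size(); ++i) {
        decode_one_token(ctx, prompt[i]); // hypothetical helper
    }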
@@ -8366,7 +8407,7 @@ static int llama_decode_internal(
         // a heuristic, to avoid attending the full cache if it is not yet utilized
         // after enough generations, the benefit from this heuristic disappears
         // if we start defragmenting the cache, the benefit from this will be more important
-        kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+        kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
         //kv_self.n = llama_kv_cache_cell_max(kv_self);
 
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
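The clamp change matters because kv_self.size is 1 for Mamba while cparams.n_ctx is whatever the user requested, so the old bound could let n_kv exceed the single cell. A worked example of the expression (GGML_PAD rounds up to a multiple of its second argument):

    // Transformer, cell_max = 50:  GGML_PAD(50, 32) = 64, so n_kv = min(kv_self.size, max(32, 64)) = 64.
    // Mamba, kv_self.size == 1:    n_kv = min(1, max(32u, GGML_PAD(cell_max, 32))) = 1.
    // With the old cparams.n_ctx bound, n_kv would have come out as 32 even though only one cell exists.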