mamba : begin figuring out how to (ab)use the kv cache for Mamba

Francis Couture-Harpin 2024-01-27 11:41:20 -05:00
parent 8cd0a286b4
commit 5a69a262a1
2 changed files with 80 additions and 35 deletions

@@ -1849,9 +1849,13 @@ class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
     def set_gguf_parameters(self):
+        d_model = self.hparams["d_model"]
+
         self.gguf_writer.add_name(self.dir_model.name)
-        self.gguf_writer.add_embedding_length(self.hparams["d_model"])
+        self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_head_count(2 * d_model) # d_inner
+        self.gguf_writer.add_key_length(4) # d_conv
+        self.gguf_writer.add_value_length(16) # d_state
         self.gguf_writer.add_file_type(self.ftype)
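Worth noting what this hunk is doing: GGUF has no Mamba-specific metadata keys yet, so the Mamba dimensions are packed into the existing attention keys — head_count carries d_inner (2 * d_model), key_length carries d_conv (4) and value_length carries d_state (16). A minimal standalone sketch of that mapping (illustrative struct and values, not llama.cpp code):

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical holder for the values read back from the (ab)used GGUF keys.
struct mamba_dims {
    int64_t d_model; // stored in the usual embedding_length key
    int64_t d_inner; // smuggled into head_count   (= 2 * d_model for Mamba)
    int64_t d_conv;  // smuggled into key_length   (4 in the reference models)
    int64_t d_state; // smuggled into value_length (16 in the reference models)
};

int main() {
    // Example: mamba-130m has d_model = 768.
    const int64_t d_model = 768;

    mamba_dims dims = {
        /* d_model */ d_model,
        /* d_inner */ 2 * d_model, // expand factor of 2, as in the commit
        /* d_conv  */ 4,
        /* d_state */ 16,
    };

    std::printf("d_inner=%lld d_conv=%lld d_state=%lld\n",
                (long long) dims.d_inner, (long long) dims.d_conv, (long long) dims.d_state);
    return 0;
}
```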

llama.cpp

@@ -2067,6 +2067,14 @@ static bool llama_kv_cache_init(
     const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
     const int64_t n_layer = hparams.n_layer;
 
+    if (model.arch == LLM_ARCH_MAMBA) {
+        // only one slot is needed for Mamba
+        n_ctx = 1;
+        // it's probably best to keep as much precision as possible for the states
+        ktype = GGML_TYPE_F32;
+        vtype = GGML_TYPE_F32;
+    }
+
     cache.has_shift = false;
 
     cache.head = 0;
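The reason a single cache slot is enough: a Mamba layer's entire memory of a sequence is a fixed-size pair of state tensors ({d_conv, d_inner} and {d_state, d_inner}), independent of how many tokens have been processed, whereas an attention KV cache grows with every cached token. A rough size comparison, assuming F32 states as in this commit and illustrative model sizes:

```cpp
#include <cstdint>
#include <cstdio>

// Rough per-layer cache sizes; illustrative only, real llama.cpp sizes also
// depend on the configured K/V cache types.

static int64_t mamba_state_bytes(int64_t d_inner, int64_t d_conv, int64_t d_state) {
    // conv state {d_conv, d_inner} + ssm state {d_state, d_inner}, both F32;
    // the size does NOT depend on how many tokens were processed
    return (d_conv * d_inner + d_state * d_inner) * (int64_t) sizeof(float);
}

static int64_t attn_kv_bytes(int64_t n_embd_gqa, int64_t n_ctx) {
    // K + V, one row per cached token, F16 (2 bytes) for comparison
    return 2 * n_embd_gqa * n_ctx * 2;
}

int main() {
    // Numbers in the ballpark of mamba-2.8b: d_model = 2560, d_inner = 5120.
    const int64_t d_inner = 5120, d_conv = 4, d_state = 16;

    std::printf("mamba per-layer state : %lld KiB (any context length)\n",
                (long long) (mamba_state_bytes(d_inner, d_conv, d_state) / 1024));
    std::printf("attn  per-layer cache : %lld KiB (n_embd_gqa = 2560, n_ctx = 4096)\n",
                (long long) (attn_kv_bytes(2560, 4096) / 1024));
    return 0;
}
```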
@@ -2151,6 +2159,12 @@ static bool llama_kv_cache_find_slot(
     const uint32_t n_ctx = cache.size;
     const uint32_t n_tokens = batch.n_tokens;
 
+    // for Mamba and/or other model archs that only ever use one slot
+    if (n_ctx == 1) {
+        // hopefully no one actually uses a context size of 1 on Transformer-based models
+        return true;
+    }
+
     if (n_tokens > n_ctx) {
         LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
         return false;
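With a single cell there is nothing to search for, so the early return makes every batch land in the one slot. A toy illustration of that degenerate case (hypothetical types, not llama.cpp's):

```cpp
#include <cstdint>
#include <vector>

// A toy single-slot "cache" in the spirit of the early return above: when there
// is only one cell, the usual free-slot search is skipped and cell 0 is always it.

struct toy_cell { int32_t pos = -1; };

struct toy_cache {
    std::vector<toy_cell> cells;

    bool find_slot(uint32_t n_tokens) {
        if (cells.size() == 1) {
            // recurrent state: the single cell is always "the" slot
            return true;
        }
        // (a real cache would search for n_tokens contiguous free cells here)
        return n_tokens <= cells.size();
    }
};

int main() {
    toy_cache cache;
    cache.cells.resize(1);
    return cache.find_slot(32) ? 0 : 1;
}
```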
@@ -4666,12 +4680,18 @@ static bool llm_load_tensors(
             {
                 model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
 
+                const int64_t d_conv = hparams.n_embd_head_k;
+                const int64_t d_state = hparams.n_embd_head_v;
+                const int64_t d_inner = hparams.n_head;
+                // FIXME: ceiling instead of floor
+                const int64_t dt_rank = n_embd / 16;
+                GGML_ASSERT(2 * n_embd == d_inner);
+
                 // output
                 {
                     model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
                     model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
                 }
 
-                // TODO: MAMBA
                 for (int i = 0; i < n_layer; ++i) {
                     ggml_context * ctx_layer = ctx_for_layer(i);
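About the `FIXME: ceiling instead of floor`: the reference Mamba implementation sets dt_rank to ceil(d_model / 16) when it is left on "auto", while the integer division above floors; the released checkpoints appear to use d_model values that are multiples of 16, so the two only differ for other sizes. A sketch of the integer ceiling:

```cpp
#include <cassert>
#include <cstdint>

// ceil(d_model / 16) with integer arithmetic, instead of the flooring division.
static int64_t dt_rank_auto(int64_t d_model) {
    return (d_model + 16 - 1) / 16;
}

int main() {
    assert(dt_rank_auto(768)  == 48);  // mamba-130m: exact multiple, floor == ceil
    assert(dt_rank_auto(2560) == 160); // mamba-2.8b: also an exact multiple
    assert(dt_rank_auto(2056) == 129); // hypothetical size where floor would give 128
    return 0;
}
```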
@@ -4679,19 +4699,30 @@ static bool llm_load_tensors(
                     auto & layer = model.layers[i];
 
-                    // norm
-                    layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                    // TODO: D, in_proj, conv1d, x_proj, dt_proj, A_log, out_proj
                     // TODO: what's the difference between ctx_layer and ctx_split?
                     // A: It seems that ctx_split is for matrices (2d???) while ctx_layer is for other things (like 1D bias and norms, probably.)
 
+                    // norm
+                    layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                    layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});
+
+                    layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, 1, d_inner});
+                    layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner});
+
+                    layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});
+
+                    layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner});
+                    layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner});
+
+                    // FIXME: maybe no suffix for these
+                    layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, "weight", i), {d_state, d_inner});
+                    layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, "weight", i), {d_inner});
+
                     // out_proj
-                    layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {2*n_embd, n_embd});
+                    layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
                 }
-            }
+            } break;
         default:
             throw std::runtime_error("unknown architecture");
     }
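To see how the shapes above chain together, here is the per-layer parameter count they imply, computed from the same dimensions (standalone arithmetic with mamba-130m-like sizes; the numbers are illustrative, not a claim about the exact checkpoint):

```cpp
#include <cstdint>
#include <cstdio>

// Per-layer parameter count implied by the tensor shapes created above
// (ggml shapes are listed as {columns, rows}). Not llama.cpp code.

int main() {
    const int64_t n_embd  = 768;          // d_model
    const int64_t d_inner = 2 * n_embd;   // expand = 2
    const int64_t d_conv  = 4;
    const int64_t d_state = 16;
    const int64_t dt_rank = n_embd / 16;

    const int64_t ssm_in     = n_embd * 2*d_inner;              // {n_embd, 2*d_inner}
    const int64_t ssm_conv1d = d_conv * d_inner + d_inner;      // weight {d_conv, 1, d_inner} + bias
    const int64_t ssm_x      = d_inner * (dt_rank + 2*d_state); // {d_inner, dt_rank + 2*d_state}
    const int64_t ssm_dt     = dt_rank * d_inner + d_inner;     // weight {dt_rank, d_inner} + bias
    const int64_t ssm_a      = d_state * d_inner;               // {d_state, d_inner}
    const int64_t ssm_d      = d_inner;                         // {d_inner}
    const int64_t ssm_out    = d_inner * n_embd;                // {d_inner, n_embd}
    const int64_t attn_norm  = n_embd;                          // {n_embd}

    const int64_t total = ssm_in + ssm_conv1d + ssm_x + ssm_dt
                        + ssm_a + ssm_d + ssm_out + attn_norm;

    std::printf("params per Mamba block at d_model=%lld: %lld\n",
                (long long) n_embd, (long long) total);
    return 0;
}
```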
@@ -5272,7 +5303,7 @@ struct llm_build_context {
         norm_eps (hparams.f_norm_eps),
         norm_rms_eps (hparams.f_norm_rms_eps),
         n_tokens (batch.n_tokens),
-        n_kv (worst_case ? n_ctx : kv_self.n),
+        n_kv (worst_case ? kv_self.size : kv_self.n),
         kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
         n_orig_ctx (cparams.n_yarn_orig_ctx),
         pooling_type (cparams.pooling_type),
@@ -7876,28 +7907,30 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_mamba() {
+    struct ggml_cgraph * build_mamba(bool use_conv) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
-        // d_model
-        const int64_t n_embd = hparams.n_embd;
-        const int64_t d_state = 16;
-        const int64_t d_conv = 4;
-        // expand = 2
-        // d_inner = expand * d_model
-        const int64_t d_inner = 2 * n_embd; // FIXME: this is wrong
+        GGML_ASSERT(use_conv == false); // TODO: implement
+
+        const int64_t d_model = hparams.n_embd;
+        const int64_t d_inner = hparams.n_head;
+        GGML_ASSERT(2 * d_model == d_inner);
+        const int64_t d_conv = hparams.n_embd_head_k;
+        const int64_t d_state = hparams.n_embd_head_v;
+        const int64_t dt_rank = d_model / 16;
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        // TODO: give it the right size
-        struct ggml_tensor * state;
-
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        // {n_embd, batch}
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         for (int il = 0; il < n_layer; ++il) {
-            // FIXME: init attn_norm
+            // (ab)using the kv cache to store the state
+            ggml_tensor * conv_state = kv_self.k_l[il]; // {d_conv, d_inner}
+            ggml_tensor * ssm_state = kv_self.v_l[il]; // {d_state, d_inner}
+
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
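For context on what ssm_state has to hold between decode calls: the selective SSM update for one token is h = exp(dt*A) * h + (dt*B) * x per channel and state, followed by y = C.h + D*x. A plain-loop reference of that single step (a sketch of the standard Mamba recurrence, not the ggml graph this function will eventually build):

```cpp
#include <cmath>
#include <vector>

// One selective-SSM step for a single token, written as plain loops.
// This is the recurrence the cached ssm_state {d_state, d_inner} has to carry
// between decode calls; a reference sketch only.

struct ssm_step_io {
    // per-token inputs produced by the projections in the block above
    std::vector<float> x;   // [d_inner]   conv output after SiLU
    std::vector<float> dt;  // [d_inner]   after dt_proj + softplus
    std::vector<float> B;   // [d_state]   input-dependent
    std::vector<float> C;   // [d_state]   input-dependent
};

void ssm_step(int d_inner, int d_state,
              const std::vector<float> & A,      // [d_inner * d_state], A = -exp(A_log)
              const std::vector<float> & D,      // [d_inner]
              const ssm_step_io & in,
              std::vector<float> & state,        // [d_inner * d_state], updated in place
              std::vector<float> & y) {          // [d_inner]
    y.assign(d_inner, 0.0f);
    for (int i = 0; i < d_inner; ++i) {
        float acc = 0.0f;
        for (int s = 0; s < d_state; ++s) {
            float & h = state[i*d_state + s];
            const float dA = std::exp(in.dt[i] * A[i*d_state + s]); // discretize A
            const float dB = in.dt[i] * in.B[s];                    // discretize B
            h = dA * h + dB * in.x[i];                              // state update
            acc += in.C[s] * h;                                     // readout
        }
        y[i] = acc + D[i] * in.x[i];                                // skip connection
    }
}

int main() {
    const int d_inner = 2, d_state = 2;
    std::vector<float> A(d_inner*d_state, -1.0f); // A = -exp(A_log), negative
    std::vector<float> D(d_inner, 1.0f);
    std::vector<float> state(d_inner*d_state, 0.0f);
    std::vector<float> y;
    ssm_step_io in;
    in.x  = {0.5f, -0.25f};
    in.dt = {0.1f, 0.2f};
    in.B  = {1.0f, 0.0f};
    in.C  = {1.0f, 1.0f};
    ssm_step(d_inner, d_state, A, D, in, state, y); // state now carries this token
    return y.size() == (size_t) d_inner ? 0 : 1;
}
```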
@@ -7905,14 +7938,18 @@ struct llm_build_context {
             // TODO: that's probably the wrong name.
             cb(cur, "attn_norm", il);
 
+            // {n_embd, batch} * {n_embd, 2*d_inner} = {batch, 2*d_inner}
+            struct ggml_tensor * xz = ggml_mul_mat(ctx0, cur, model.layers[il].ssm_in);
+            // split the above in two
+            struct ggml_tensor * x = ggml_view_1d(ctx0, xz, d_inner, 0);
+            struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, d_inner);
+
+            // FIXME: figure out when to transpose
+
             // conv
             {
-                // [] * [] = [2*n_embd]
-                struct ggml_tensor * xz = ggml_mul_mat(ctx0, cur, model.layers[il].ssm_in);
-                // split the above in two
-                struct ggml_tensor * x = ggml_view_1d(ctx0, xz, d_inner, 0);
-                struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, d_inner);
+                // TODO: figure out how to do a row-wise dot product
+                // TODO: use the kv-cache to store the state
+                kv_self.k_l[il];
 
                 // FIXME: this is wrong
                 cur = ggml_conv_1d(ctx0, cur, model.layers[il].ssm_conv1d, 1, d_conv - 1, 1);
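The cached conv_state is what replaces ggml_conv_1d at generation time: each of the d_inner channels keeps its last d_conv inputs, shifts in the new one and takes a dot product with its own d_conv weights — the "row-wise dot product" the TODO above asks about. A plain-C++ sketch of that single step (reference only, not the eventual ggml implementation):

```cpp
#include <vector>

// Single-token update of the cached conv_state {d_conv, d_inner}:
// shift each channel's window left by one, append the new input,
// then take a per-channel (depthwise) dot product with the conv weights.

void conv_step(int d_inner, int d_conv,
               const std::vector<float> & w,      // [d_inner * d_conv] depthwise weights
               const std::vector<float> & b,      // [d_inner] bias
               const std::vector<float> & x_new,  // [d_inner] this token's x
               std::vector<float> & conv_state,   // [d_inner * d_conv], updated in place
               std::vector<float> & x_out) {      // [d_inner]
    x_out.assign(d_inner, 0.0f);
    for (int i = 0; i < d_inner; ++i) {
        float * s = &conv_state[i*d_conv];
        // shift the window left by one and append the new input
        for (int k = 0; k + 1 < d_conv; ++k) s[k] = s[k + 1];
        s[d_conv - 1] = x_new[i];
        // per-channel dot product
        float acc = b[i];
        for (int k = 0; k < d_conv; ++k) acc += s[k] * w[i*d_conv + k];
        x_out[i] = acc;
    }
}

int main() {
    const int d_inner = 2, d_conv = 4;
    std::vector<float> w(d_inner*d_conv, 0.25f), b(d_inner, 0.0f);
    std::vector<float> x_new = {1.0f, 2.0f};
    std::vector<float> conv_state(d_inner*d_conv, 0.0f);
    std::vector<float> x_out;
    conv_step(d_inner, d_conv, w, b, x_new, conv_state, x_out);
    return x_out.size() == 2 ? 0 : 1;
}
```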
@@ -8111,6 +8148,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_starcoder2();
            } break;
+        case LLM_ARCH_MAMBA:
+            {
+                result = llm.build_mamba(/* use_conv =*/ batch.n_tokens > 1);
+            } break;
         default:
             GGML_ASSERT(false);
     }
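A note on the use_conv flag: a prompt arrives as a whole sequence, so the convolution can run over all of its tokens at once, while generation feeds one token at a time and only needs the cached states. A trivial sketch of that dispatch (hypothetical names, only to make the split explicit):

```cpp
#include <cstdint>

// Sketch of why use_conv is tied to batch.n_tokens > 1:
// multi-token batches (prompt processing) take the conv-over-sequence path,
// single-token batches (generation) take the cached-state path.

enum class mamba_path { conv_over_sequence, single_step_from_state };

static mamba_path choose_mamba_path(uint32_t n_tokens) {
    return n_tokens > 1 ? mamba_path::conv_over_sequence
                        : mamba_path::single_step_from_state;
}

int main() {
    return choose_mamba_path(1) == mamba_path::single_step_from_state ? 0 : 1;
}
```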
@@ -8366,7 +8407,7 @@ static int llama_decode_internal(
         // a heuristic, to avoid attending the full cache if it is not yet utilized
         // after enough generations, the benefit from this heuristic disappears
         // if we start defragmenting the cache, the benefit from this will be more important
-        kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+        kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
         //kv_self.n = llama_kv_cache_cell_max(kv_self);
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);