diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 28e865e5c..fab409b08 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1849,9 +1849,13 @@ class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
     def set_gguf_parameters(self):
+        d_model = self.hparams["d_model"]
         self.gguf_writer.add_name(self.dir_model.name)
-        self.gguf_writer.add_embedding_length(self.hparams["d_model"])
+        self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_head_count(2 * d_model)  # d_inner
+        self.gguf_writer.add_key_length(4)            # d_conv
+        self.gguf_writer.add_value_length(16)         # d_state
         self.gguf_writer.add_file_type(self.ftype)
diff --git a/llama.cpp b/llama.cpp
index 2dc1cc1b3..c46f669e3 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1765,7 +1765,7 @@ struct llama_layer {
     struct ggml_tensor * ffn_up_b;  // b3
     struct ggml_tensor * ffn_act;
-
+    // mamba proj
     struct ggml_tensor * ssm_in;
     struct ggml_tensor * ssm_x;
@@ -2067,6 +2067,14 @@ static bool llama_kv_cache_init(
     const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
     const int64_t  n_layer      = hparams.n_layer;
 
+    if (model.arch == LLM_ARCH_MAMBA) {
+        // only one slot is needed for Mamba
+        n_ctx = 1;
+        // it's probably best to keep as much precision as possible for the states
+        ktype = GGML_TYPE_F32;
+        vtype = GGML_TYPE_F32;
+    }
+
     cache.has_shift = false;
 
     cache.head = 0;
@@ -2151,6 +2159,12 @@ static bool llama_kv_cache_find_slot(
     const uint32_t n_ctx    = cache.size;
     const uint32_t n_tokens = batch.n_tokens;
 
+    // for Mamba and/or other model archs that only ever use one slot
+    if (n_ctx == 1) {
+        // hopefully no one actually uses a context size of 1 on Transformer-based models
+        return true;
+    }
+
     if (n_tokens > n_ctx) {
         LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
         return false;
@@ -4665,13 +4679,19 @@ static bool llm_load_tensors(
             case LLM_ARCH_MAMBA:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
+
+                    const int64_t d_conv  = hparams.n_embd_head_k;
+                    const int64_t d_state = hparams.n_embd_head_v;
+                    const int64_t d_inner = hparams.n_head;
+                    // FIXME: ceiling instead of floor
+                    const int64_t dt_rank = n_embd / 16;
+                    GGML_ASSERT(2 * n_embd == d_inner);
+
                     // output
                     {
                         model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
                         model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
                     }
-                    // TODO: MAMBA
+
                     for (int i = 0; i < n_layer; ++i) {
                         ggml_context * ctx_layer = ctx_for_layer(i);
@@ -4679,19 +4699,30 @@ static bool llm_load_tensors(
 
                         auto & layer = model.layers[i];
 
-                        // norm
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        // TODO: D, in_proj, conv1d, x_proj, dt_proj, A_log, out_proj
                         // TODO: what's the difference between ctx_layer and ctx_split?
                         //   A: It seems that ctx_split is for matrices (2d???) while ctx_layer is for other things (like 1D bias and norms, probably.)
-                        // out_proj
-                        layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {2*n_embd, n_embd});
+                        // norm
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
-
+                        layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});
+
+                        layer.ssm_conv1d   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, 1, d_inner});
+                        layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i),   {d_inner});
+
+                        layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});
+
+                        layer.ssm_dt   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner});
+                        layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i),   {d_inner});
+
+                        // FIXME: maybe no suffix for these
+                        layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, "weight", i), {d_state, d_inner});
+                        layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, "weight", i), {d_inner});
+
+                        // out_proj
+                        layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
                     }
-                }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -5272,7 +5303,7 @@ struct llm_build_context {
         norm_eps         (hparams.f_norm_eps),
         norm_rms_eps     (hparams.f_norm_rms_eps),
         n_tokens         (batch.n_tokens),
-        n_kv             (worst_case ? n_ctx            : kv_self.n),
+        n_kv             (worst_case ? kv_self.size     : kv_self.n),
         kv_head          (worst_case ? n_ctx - n_tokens : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
         pooling_type     (cparams.pooling_type),
@@ -7876,28 +7907,30 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_mamba() {
+    struct ggml_cgraph * build_mamba(bool use_conv) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
-        // d_model
-        const int64_t n_embd = hparams.n_embd;
-        const int64_t d_state = 16;
-        const int64_t d_conv = 4;
-        // expand = 2
-        // d_inner = expand * d_model
-        const int64_t d_inner = 2 * n_embd; // FIXME: this is wrong
+        GGML_ASSERT(use_conv == false); // TODO: implement
+
+        const int64_t d_model = hparams.n_embd;
+        const int64_t d_inner = hparams.n_head;
+        GGML_ASSERT(2 * d_model == d_inner);
+        const int64_t d_conv  = hparams.n_embd_head_k;
+        const int64_t d_state = hparams.n_embd_head_v;
+        const int64_t dt_rank = d_model / 16;
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        // TODO: give it the right size
-        struct ggml_tensor * state;
-
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        // {n_embd, batch}
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         for (int il = 0; il < n_layer; ++il) {
-            // FIXME: init attn_norm
+            // (ab)using the kv cache to store the state
+            ggml_tensor * conv_state = kv_self.k_l[il]; // {d_conv, d_inner}
+            ggml_tensor * ssm_state  = kv_self.v_l[il]; // {d_state, d_inner}
+
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
@@ -7905,15 +7938,19 @@
             // TODO: that's probably the wrong name.
cb(cur, "attn_norm", il); + // {n_embd, batch} * {n_embd, 2*d_inner} = {batch, 2*d_inner} + struct ggml_tensor * xz = ggml_mul_mat(ctx0, cur, model.layers[il].ssm_in); + // split the above in two + struct ggml_tensor * x = ggml_view_1d(ctx0, xz, d_inner, 0); + struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, d_inner); + + // FIXME: figure out when to transpose // conv { - // [] * [] = [2*n_embd] - struct ggml_tensor * xz = ggml_mul_mat(ctx0, cur, model.layers[il].ssm_in); - // split the above in two - struct ggml_tensor * x = ggml_view_1d(ctx0, xz, d_inner, 0); - struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, d_inner); + // TODO: figure out how to do a row-wise dot product + // TODO: use the kv-cache to store the state + kv_self.k_l[il]; - // FIXME: this is wrong cur = ggml_conv_1d(ctx0, cur, model.layers[il].ssm_conv1d, 1, d_conv - 1, 1); @@ -7925,9 +7962,9 @@ struct llm_build_context { // ssm { - + // TODO: use ggml_soft_plus here - + } // TODO: there's some SiLU again towards the end. Can the `llm_build_ffn` helper be used? @@ -8111,6 +8148,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_starcoder2(); } break; + case LLM_ARCH_MAMBA: + { + result = llm.build_mamba(/* use_conv =*/ batch.n_tokens > 1); + } break; default: GGML_ASSERT(false); } @@ -8366,7 +8407,7 @@ static int llama_decode_internal( // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); //kv_self.n = llama_kv_cache_cell_max(kv_self); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);