diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index de7bf431f..bed830ce6 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1857,7 +1857,14 @@ class MambaModel(Model):
 
     def set_gguf_parameters(self):
         d_model = self.hparams["d_model"]
+        d_conv  = self.hparams.get("d_conv", 4)
         d_inner = self.hparams.get("d_inner", 2 * d_model)
+        d_state = self.hparams.get("d_state", 16)
+        # ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
+        dt_rank = self.hparams.get("dt_rank", -(d_model // -16))
+
         # Fail early for models which don't have a block expansion factor of 2
         assert d_inner == 2 * d_model
 
@@ -1865,13 +1872,13 @@ class MambaModel(Model):
         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
-        self.gguf_writer.add_head_count(d_inner) # the number of rows in conv_state and ssm_state
+        self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
         self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_ssm_conv_kernel_size(d_conv)
+        self.gguf_writer.add_ssm_inner_length(d_inner)
+        self.gguf_writer.add_ssm_state_length(d_state)
+        self.gguf_writer.add_ssm_dt_rank(dt_rank)
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
-        # NOTE: (ab)using the KV cache metadata to store dimensions for conv_state and ssm_state
-        # Since the first column of the conv_state is shifted out each time, it's not actually needed
-        self.gguf_writer.add_key_length(self.hparams.get("d_conv", 4) - 1)
-        self.gguf_writer.add_value_length(self.hparams.get("d_state", 16))
         self.gguf_writer.add_file_type(self.ftype)
 
     def write_tensors(self):
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 651323a1e..8030023f3 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -61,6 +61,12 @@ class Keys:
         SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
         SCALING_FINETUNED    = "{arch}.rope.scaling.finetuned"
 
+    class SSM:
+        CONV_KERNEL_SIZE = "{arch}.ssm.d_conv"
+        INNER_LENGTH     = "{arch}.ssm.d_inner"
+        STATE_LENGTH     = "{arch}.ssm.d_state"
+        DT_RANK          = "{arch}.ssm.dt_rank"
+
     class Tokenizer:
         MODEL = "tokenizer.ggml.model"
         LIST  = "tokenizer.ggml.tokens"
@@ -763,6 +769,12 @@ KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR
 KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
 KEY_ROPE_SCALING_FINETUNED    = Keys.Rope.SCALING_FINETUNED
 
+# SSM
+KEY_SSM_CONV_KERNEL_SIZE = Keys.SSM.CONV_KERNEL_SIZE
+KEY_SSM_INNER_LENGTH     = Keys.SSM.INNER_LENGTH
+KEY_SSM_STATE_LENGTH     = Keys.SSM.STATE_LENGTH
+KEY_SSM_DT_RANK          = Keys.SSM.DT_RANK
+
 # tokenization
 KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
 KEY_TOKENIZER_LIST  = Keys.Tokenizer.LIST
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 801160832..146358e69 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -382,6 +382,18 @@ class GGUFWriter:
     def add_rope_scaling_finetuned(self, value: bool) -> None:
         self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
 
+    def add_ssm_conv_kernel_size(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.CONV_KERNEL_SIZE.format(arch=self.arch), value)
+
+    def add_ssm_inner_length(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.INNER_LENGTH.format(arch=self.arch), value)
+
+    def add_ssm_state_length(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.STATE_LENGTH.format(arch=self.arch), value)
+
+    def add_ssm_dt_rank(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.DT_RANK.format(arch=self.arch), value)
+
     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)
 
diff --git a/llama.cpp b/llama.cpp
index f437059b2..eb1f02e42 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -286,6 +286,11 @@ enum llm_kv {
     LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
    LLM_KV_ROPE_SCALING_FINETUNED,
 
+    LLM_KV_SSM_D_INNER,
+    LLM_KV_SSM_D_CONV,
+    LLM_KV_SSM_D_STATE,
+    LLM_KV_SSM_DT_RANK,
+
     LLM_KV_TOKENIZER_MODEL,
     LLM_KV_TOKENIZER_LIST,
     LLM_KV_TOKENIZER_TOKEN_TYPE,
@@ -344,6 +349,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
     { LLM_KV_ROPE_SCALING_FINETUNED,    "%s.rope.scaling.finetuned"                },
 
+    { LLM_KV_SSM_D_CONV,                "%s.ssm.d_conv"  },
+    { LLM_KV_SSM_D_INNER,               "%s.ssm.d_inner" },
+    { LLM_KV_SSM_D_STATE,               "%s.ssm.d_state" },
+    { LLM_KV_SSM_DT_RANK,               "%s.ssm.dt_rank" },
+
     { LLM_KV_TOKENIZER_MODEL,           "tokenizer.ggml.model"      },
     { LLM_KV_TOKENIZER_LIST,            "tokenizer.ggml.tokens"     },
     { LLM_KV_TOKENIZER_TOKEN_TYPE,      "tokenizer.ggml.token_type" },
@@ -1638,6 +1648,12 @@ struct llama_hparams {
     float    rope_freq_scale_train;
     uint32_t n_yarn_orig_ctx;
 
+    // for State Space Models
+    uint32_t ssm_d_conv  = 0;
+    uint32_t ssm_d_inner = 0;
+    uint32_t ssm_d_state = 0;
+    uint32_t ssm_dt_rank = 0;
+
     float f_clamp_kqv      = 0.0f;
     float f_max_alibi_bias = 0.0f;
 
@@ -1666,6 +1682,11 @@ struct llama_hparams {
         if (this->rope_finetuned  != other.rope_finetuned)  return true;
         if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
 
+        if (this->ssm_d_conv  != other.ssm_d_conv)  return true;
+        if (this->ssm_d_inner != other.ssm_d_inner) return true;
+        if (this->ssm_d_state != other.ssm_d_state) return true;
+        if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
+
         const float EPSILON = 1e-9f;
 
         if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
@@ -1677,6 +1698,9 @@ struct llama_hparams {
     }
 
     uint32_t n_gqa() const {
+        if (n_head_kv == 0) {
+            return 0;
+        }
         return n_head/n_head_kv;
     }
 
@@ -1687,6 +1711,18 @@ struct llama_hparams {
     uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads
         return n_embd_head_v * n_head_kv;
     }
+
+    uint32_t n_embd_k_s() const { // dimension of the recurrent convolution state embeddings
+        // corresponds to Mamba's conv_states size
+        // TODO: maybe support convolution strides other than 1
+        // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+    }
+
+    uint32_t n_embd_v_s() const { // dimension of the ssm scan state embeddings
+        // corresponds to Mamba's ssm_states size
+        return ssm_d_state * ssm_d_inner;
+    }
 };
 
 struct llama_cparams {
@@ -1804,8 +1840,8 @@ struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
     bool do_copy   = false;
-    // with Mamba, a cell can hold the state for more than one past token
-    bool unlimited = false;
+    // with recurrent state models, a cell can hold the state for more than one past token
+    bool recurrent = false;
 
     // Note: The value of head isn't only used to optimize searching
     // for a free KV slot. llama_decode_internal also uses it, so it
@@ -2067,14 +2103,21 @@ static bool llama_kv_cache_init(
                              bool   offload) {
     const struct llama_hparams & hparams = model.hparams;
 
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
     const int64_t  n_layer      = hparams.n_layer;
 
     cache.has_shift = false;
 
-    // for now, only Mamba can hold state for more than one past token per cell
-    cache.unlimited = model.arch == LLM_ARCH_MAMBA;
+    // TODO: find a nicer way to add other recurrent model architectures
+    cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+
+    // TODO: support mixed recurrent Transformer architectures
+    // NOTE: (!a || b) is a logical implication (a -> b)
+    GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s());
+    GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s());
+    GGML_ASSERT( cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_gqa());
+    GGML_ASSERT( cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_gqa());
 
     cache.head = 0;
     cache.size = kv_size;
@@ -2086,7 +2129,8 @@ static bool llama_kv_cache_init(
     cache.cells.clear();
     cache.cells.resize(kv_size);
 
-    if (cache.unlimited) {
+    if (cache.recurrent) {
+        // init state copy sources
         for (uint32_t i = 0; i < cache.size; ++i) {
             cache.cells[i].src = i;
         }
@@ -2164,8 +2208,8 @@ static bool llama_kv_cache_find_slot(
     const uint32_t n_ctx    = cache.size;
     const uint32_t n_tokens = batch.n_tokens;
 
-    if (cache.unlimited) {
-        // For unlimited context architectures (like Mamba),
+    if (cache.recurrent) {
+        // For recurrent state architectures (like Mamba),
         // each KV cache cell can store the state for a whole sequence.
 
         // starting point to find the minimum seq_id used in the batch
@@ -2289,7 +2333,7 @@ static bool llama_kv_cache_seq_rm(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
     // models like Mamba can't have a state partially erased
-    if (cache.unlimited) {
+    if (cache.recurrent) {
        if (seq_id >= (int64_t) cache.size) {
             // could be fatal
             return false;
@@ -2341,7 +2385,7 @@ static void llama_kv_cache_seq_cp(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
-    if (cache.unlimited) {
+    if (cache.recurrent) {
         if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
             seq_id_src = cache.cells[seq_id_src].src;
             GGML_ASSERT((uint32_t) seq_id_src < cache.size);
@@ -2403,7 +2447,7 @@ static void llama_kv_cache_seq_add(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
-    if (cache.unlimited) {
+    if (cache.recurrent) {
         // for Mamba-like models, only the pos needs to be shifted
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
             llama_kv_cell & cell = cache.cells[seq_id];
@@ -2447,7 +2491,7 @@ static void llama_kv_cache_seq_div(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
-    if (cache.unlimited) {
+    if (cache.recurrent) {
         // for Mamba-like models, only the pos needs to be changed
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
             llama_kv_cell & cell = cache.cells[seq_id];
@@ -3277,7 +3321,7 @@ static void llm_load_hparams(
 
     // sanity check for n_rot (optional)
     {
-        hparams.n_rot = hparams.n_embd / hparams.n_head;
+        hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
 
         ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
 
@@ -3290,10 +3334,10 @@ static void llm_load_hparams(
         // gpt-j n_rot = rotary_dim
     }
 
-    hparams.n_embd_head_k = hparams.n_embd / hparams.n_head;
+    hparams.n_embd_head_k = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
     ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
 
-    hparams.n_embd_head_v = hparams.n_embd / hparams.n_head;
+    hparams.n_embd_head_v = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
     ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
 
     // arch-specific KVs
@@ -3545,7 +3589,13 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_MAMBA:
             {
+                ml.get_key(LLM_KV_SSM_D_CONV,  hparams.ssm_d_conv);
+                ml.get_key(LLM_KV_SSM_D_INNER, hparams.ssm_d_inner);
+                ml.get_key(LLM_KV_SSM_D_STATE, hparams.ssm_d_state);
+                ml.get_key(LLM_KV_SSM_DT_RANK, hparams.ssm_dt_rank);
+
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
                 switch (hparams.n_layer) {
                     case 24:
                         switch (hparams.n_embd) {
@@ -3886,6 +3936,10 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
     LLAMA_LOG_INFO("%s: n_yarn_orig_ctx  = %u\n", __func__, hparams.n_yarn_orig_ctx);
     LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
+    LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n", __func__, hparams.ssm_d_conv);
+    LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n", __func__, hparams.ssm_d_inner);
+    LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n", __func__, hparams.ssm_d_state);
+    LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n", __func__, hparams.ssm_dt_rank);
     LLAMA_LOG_INFO("%s: model type       = %s\n", __func__, llama_model_type_name(model.type));
     LLAMA_LOG_INFO("%s: model ftype      = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
     if (ml.n_elements >= 1e12) {
@@ -4050,10 +4104,7 @@ static bool llm_load_tensors(
     const int64_t n_vocab_type = hparams.n_vocab_type;
     const int64_t n_ff         = hparams.n_ff;
 
-    // Mamba uses these in its own way
-    if (model.arch != LLM_ARCH_MAMBA) {
-        GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
-    }
+    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
 
     ggml_context * ctx_input  = ctx_map.at(model.buft_input.buft);
     ggml_context * ctx_output = ctx_map.at(model.buft_output.buft);
@@ -4792,12 +4843,11 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_MAMBA:
                 {
-                    const int64_t d_conv = hparams.n_embd_head_k + 1;
-                    const int64_t d_state = hparams.n_embd_head_v;
-                    const int64_t d_inner = hparams.n_head;
-                    // TODO: allow loading dt_rank from the model config
-                    // ceiling division
-                    const int64_t dt_rank = (n_embd / 16) + (n_embd % 16 > 0);
+                    const int64_t d_conv  = hparams.ssm_d_conv;
+                    const int64_t d_inner = hparams.ssm_d_inner;
+                    const int64_t d_state = hparams.ssm_d_state;
+                    const int64_t dt_rank = hparams.ssm_dt_rank;
+                    // only an expansion factor of 2 is supported for now
                     GGML_ASSERT(2 * n_embd == d_inner);
 
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -5420,7 +5470,7 @@ struct llm_build_context {
         norm_rms_eps     (hparams.f_norm_rms_eps),
         n_tokens         (batch.n_tokens),
         n_kv             (worst_case ? kv_self.size : kv_self.n),
-        kv_head          (worst_case ? (kv_self.unlimited ? 0 : kv_self.size - n_tokens) : kv_self.head),
+        kv_head          (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
         pooling_type     (cparams.pooling_type),
         rope_type        (hparams.rope_type),
@@ -5473,8 +5523,8 @@ struct llm_build_context {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
         for (int il = 0; il < n_layer; ++il) {
-            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
-            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
+            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
             conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
             ssm_states  = ggml_get_rows(ctx0, ssm_states,  lctx.inp_s_copy);
@@ -8048,12 +8098,11 @@ struct llm_build_context {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
         const int64_t d_model = n_embd;
-        const int64_t d_inner = n_head;
+        const int64_t d_conv  = hparams.ssm_d_conv;
+        const int64_t d_inner = hparams.ssm_d_inner;
         GGML_ASSERT(2 * d_model == d_inner);
-        const int64_t d_conv = n_embd_head_k + 1;
-        const int64_t d_state = n_embd_head_v;
-        // ceiling division
-        const int64_t dt_rank = (d_model / 16) + (d_model % 16 > 0);
+        const int64_t d_state = hparams.ssm_d_state;
+        const int64_t dt_rank = hparams.ssm_dt_rank;
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -8063,10 +8112,9 @@ struct llm_build_context {
         cb(inpL, "inp_embd", -1);
 
         for (int il = 0; il < n_layer; ++il) {
-            // (ab)using the kv cache to store the state
-            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], (d_conv-1)*(d_inner), kv_self.size);
-            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], (d_state)*(d_inner), kv_self.size);
+            // (ab)using the KV cache to store the states
+            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
 
             // clear states of sequences which are starting at the beginning of this batch
             {
@@ -8501,7 +8549,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (kv_self.unlimited) {
+    if (kv_self.recurrent) {
         const int64_t n_kv = kv_self.n;
 
         {
@@ -8667,7 +8715,7 @@ static int llama_decode_internal(
         return 1;
     }
 
-    if (!kv_self.unlimited) {
+    if (!kv_self.recurrent) {
         // a heuristic, to avoid attending the full cache if it is not yet utilized
         // after enough generations, the benefit from this heuristic disappears
         // if we start defragmenting the cache, the benefit from this will be more important
@@ -9056,7 +9104,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         }
     }
 
-    if (lctx.kv_self.unlimited && lctx.kv_self.do_copy) {
+    if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
         llama_set_s_copy(lctx);
 
         {
@@ -12725,7 +12773,7 @@ struct llama_context * llama_new_context_with_model(
         // graph inputs
         {
             ggml_init_params init_params = {
-                /* .mem_size   */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.unlimited)),
+                /* .mem_size   */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.recurrent)),
                 /* .mem_buffer */ nullptr,
                 /* .no_alloc   */ true,
             };
@@ -12739,7 +12787,7 @@ struct llama_context * llama_new_context_with_model(
             ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
             ctx->inp_mean    = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
             ctx->inp_cls     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
-            if (ctx->kv_self.unlimited) {
+            if (ctx->kv_self.recurrent) {
                 ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
                 ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
                 ctx->inp_s_seq  = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);
@@ -12753,7 +12801,7 @@ struct llama_context * llama_new_context_with_model(
             ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
             ggml_set_name(ctx->inp_mean,    "inp_mean");
             ggml_set_name(ctx->inp_cls,     "inp_cls");
-            if (ctx->kv_self.unlimited) {
+            if (ctx->kv_self.recurrent) {
                 ggml_set_name(ctx->inp_s_copy, "inp_s_copy");
                 ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
                 ggml_set_name(ctx->inp_s_seq,  "inp_s_seq");
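
Not part of the patch, but a quick way to sanity-check the new metadata keys end to end: a minimal round-trip sketch with gguf-py, assuming a checkout that includes the `add_ssm_*` writer methods added above. The file name and the Mamba-130M-style hyperparameters (d_model = 768, so d_inner = 1536 and dt_rank = 48) are illustrative values, not taken from this diff.

```python
# Hypothetical sanity check, not part of the patch: write the four new SSM
# keys with the methods introduced above, then read them back.
from gguf import GGUFReader, GGUFWriter

writer = GGUFWriter("mamba-kv-test.gguf", arch="mamba")
writer.add_block_count(24)              # n_layer (example value)
writer.add_ssm_conv_kernel_size(4)      # d_conv
writer.add_ssm_inner_length(1536)       # d_inner = 2 * d_model
writer.add_ssm_state_length(16)         # d_state
writer.add_ssm_dt_rank(48)              # ceil(d_model / 16) for d_model = 768
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()          # no tensors; metadata-only test file
writer.close()

reader = GGUFReader("mamba-kv-test.gguf")
for name in ("mamba.ssm.d_conv", "mamba.ssm.d_inner", "mamba.ssm.d_state", "mamba.ssm.dt_rank"):
    field = reader.fields[name]
    print(name, "=", field.parts[field.data[0]][0])  # each key is stored as a uint32
```

The key names printed here are the same `%s.ssm.*` strings that llama.cpp resolves through `LLM_KV_SSM_D_*` in `llm_load_hparams`, which is what replaces the old key_length/value_length workaround for storing the conv and ssm state dimensions.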