From 206e8ee2b2c540e67f6631989fc677bcc69e589a Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Wed, 28 Feb 2024 10:58:17 -0500
Subject: [PATCH] mamba : stop abusing attention metadata

This breaks existing converted-to-GGUF Mamba models, but will allow
supporting mixed architectures like MambaFormer without needing to
break Mamba models.

This will also allow changing the size of Mamba's states without
having to reconvert models in the future.
(e.g. using something other than d_conv - 1 columns for the conv_states
 will not require breaking existing converted Mamba models again)

* gguf-py : add new KV metadata key-value pairs for Mamba

* llama : add new metadata key-value pairs for Mamba

* llama : guard against divisions by zero when n_head is 0

* mamba : rename "unlimited" KV cache property to "recurrent"
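For illustration only (not part of this change): a minimal, runnable sketch
of how the new metadata values are derived from d_model, and of the
per-layer state sizes that llama.cpp now computes from them. The concrete
numbers assume mamba-130m-style dimensions and are only an example.

    # stand-alone sketch; nothing here is read from a real model
    d_model = 768                    # example embedding width
    d_conv  = 4                      # convolution kernel size
    d_state = 16                     # SSM state size
    d_inner = 2 * d_model            # block expansion factor of 2
    dt_rank = -(d_model // -16)      # ceil(d_model / 16), as in convert-hf-to-gguf.py below

    # per-layer state sizes, matching the new n_embd_k_s()/n_embd_v_s() helpers
    conv_state_size = (d_conv - 1) * d_inner  # first conv column is shifted out, hence d_conv - 1
    ssm_state_size  = d_state * d_inner

    print(d_inner, dt_rank, conv_state_size, ssm_state_size)  # 1536 48 4608 24576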
---
 convert-hf-to-gguf.py       |  17 +++--
 gguf-py/gguf/constants.py   |  12 ++++
 gguf-py/gguf/gguf_writer.py |  12 ++++
 llama.cpp                   | 136 ++++++++++++++++++++++++------------
 4 files changed, 128 insertions(+), 49 deletions(-)

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index de7bf431f..bed830ce6 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1857,7 +1857,14 @@ class MambaModel(Model):
 
     def set_gguf_parameters(self):
         d_model = self.hparams["d_model"]
+        d_conv  = self.hparams.get("d_conv",  4)
         d_inner = self.hparams.get("d_inner", 2 * d_model)
+        d_state = self.hparams.get("d_state", 16)
+        # ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
+        dt_rank = self.hparams.get("dt_rank", -(d_model // -16))
+
         # Fail early for models which don't have a block expansion factor of 2
         assert d_inner == 2 * d_model
 
@@ -1865,13 +1872,13 @@ class MambaModel(Model):
         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
-        self.gguf_writer.add_head_count(d_inner) # the number of rows in conv_state and ssm_state
+        self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
         self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_ssm_conv_kernel_size(d_conv)
+        self.gguf_writer.add_ssm_inner_length(d_inner)
+        self.gguf_writer.add_ssm_state_length(d_state)
+        self.gguf_writer.add_ssm_dt_rank(dt_rank)
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
-        # NOTE: (ab)using the KV cache metadata to store dimensions for conv_state and ssm_state
-        # Since the first column of the conv_state is shifted out each time, it's not actually needed
-        self.gguf_writer.add_key_length(self.hparams.get("d_conv", 4) - 1)
-        self.gguf_writer.add_value_length(self.hparams.get("d_state", 16))
         self.gguf_writer.add_file_type(self.ftype)
 
     def write_tensors(self):
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 651323a1e..8030023f3 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -61,6 +61,12 @@ class Keys:
         SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
         SCALING_FINETUNED    = "{arch}.rope.scaling.finetuned"
 
+    class SSM:
+        CONV_KERNEL_SIZE = "{arch}.ssm.d_conv"
+        INNER_LENGTH     = "{arch}.ssm.d_inner"
+        STATE_LENGTH     = "{arch}.ssm.d_state"
+        DT_RANK          = "{arch}.ssm.dt_rank"
+
     class Tokenizer:
         MODEL = "tokenizer.ggml.model"
         LIST  = "tokenizer.ggml.tokens"
@@ -763,6 +769,12 @@ KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR
 KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
 KEY_ROPE_SCALING_FINETUNED    = Keys.Rope.SCALING_FINETUNED
 
+# SSM
+KEY_SSM_CONV_KERNEL_SIZE = Keys.SSM.CONV_KERNEL_SIZE
+KEY_SSM_INNER_LENGTH     = Keys.SSM.INNER_LENGTH
+KEY_SSM_STATE_LENGTH     = Keys.SSM.STATE_LENGTH
+KEY_SSM_DT_RANK          = Keys.SSM.DT_RANK
+
 # tokenization
 KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
 KEY_TOKENIZER_LIST  = Keys.Tokenizer.LIST
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 801160832..146358e69 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -382,6 +382,18 @@ class GGUFWriter:
     def add_rope_scaling_finetuned(self, value: bool) -> None:
         self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
 
+    def add_ssm_conv_kernel_size(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.CONV_KERNEL_SIZE.format(arch=self.arch), value)
+
+    def add_ssm_inner_length(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.INNER_LENGTH.format(arch=self.arch), value)
+
+    def add_ssm_state_length(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.STATE_LENGTH.format(arch=self.arch), value)
+
+    def add_ssm_dt_rank(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.DT_RANK.format(arch=self.arch), value)
+
     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)
 
diff --git a/llama.cpp b/llama.cpp
index f437059b2..eb1f02e42 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -286,6 +286,11 @@ enum llm_kv {
     LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
     LLM_KV_ROPE_SCALING_FINETUNED,
 
+    LLM_KV_SSM_D_INNER,
+    LLM_KV_SSM_D_CONV,
+    LLM_KV_SSM_D_STATE,
+    LLM_KV_SSM_DT_RANK,
+
     LLM_KV_TOKENIZER_MODEL,
     LLM_KV_TOKENIZER_LIST,
     LLM_KV_TOKENIZER_TOKEN_TYPE,
@@ -344,6 +349,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
     { LLM_KV_ROPE_SCALING_FINETUNED,    "%s.rope.scaling.finetuned"               },
 
+    { LLM_KV_SSM_D_CONV,  "%s.ssm.d_conv"  },
+    { LLM_KV_SSM_D_INNER, "%s.ssm.d_inner" },
+    { LLM_KV_SSM_D_STATE, "%s.ssm.d_state" },
+    { LLM_KV_SSM_DT_RANK, "%s.ssm.dt_rank" },
+
     { LLM_KV_TOKENIZER_MODEL,      "tokenizer.ggml.model"      },
     { LLM_KV_TOKENIZER_LIST,       "tokenizer.ggml.tokens"     },
     { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" },
@@ -1638,6 +1648,12 @@ struct llama_hparams {
     float    rope_freq_scale_train;
     uint32_t n_yarn_orig_ctx;
 
+    // for State Space Models
+    uint32_t ssm_d_conv  = 0;
+    uint32_t ssm_d_inner = 0;
+    uint32_t ssm_d_state = 0;
+    uint32_t ssm_dt_rank = 0;
+
     float f_clamp_kqv      = 0.0f;
     float f_max_alibi_bias = 0.0f;
 
@@ -1666,6 +1682,11 @@ struct llama_hparams {
         if (this->rope_finetuned  != other.rope_finetuned)  return true;
         if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
 
+        if (this->ssm_d_conv  != other.ssm_d_conv)  return true;
+        if (this->ssm_d_inner != other.ssm_d_inner) return true;
+        if (this->ssm_d_state != other.ssm_d_state) return true;
+        if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
+
         const float EPSILON = 1e-9f;
 
         if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
@@ -1677,6 +1698,9 @@ struct llama_hparams {
     }
 
     uint32_t n_gqa() const {
+        if (n_head_kv == 0) {
+            return 0;
+        }
         return n_head/n_head_kv;
     }
 
@@ -1687,6 +1711,18 @@ struct llama_hparams {
     uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads
         return n_embd_head_v * n_head_kv;
     }
+
+    uint32_t n_embd_k_s() const { // dimension of the recurrent convolution state embeddings
+        // corresponds to Mamba's conv_states size
+        // TODO: maybe support convolution strides other than 1
+        // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+    }
+
+    uint32_t n_embd_v_s() const { // dimension of the ssm scan state embeddings
+        // corresponds to Mamba's ssm_states size
+        return ssm_d_state * ssm_d_inner;
+    }
 };
 
 struct llama_cparams {
@@ -1804,8 +1840,8 @@ struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
     bool do_copy   = false;
-    // with Mamba, a cell can hold the state for more than one past token
-    bool unlimited = false;
+    // with recurrent state models, a cell can hold the state for more than one past token
+    bool recurrent = false;
 
     // Note: The value of head isn't only used to optimize searching
     // for a free KV slot. llama_decode_internal also uses it, so it
@@ -2067,14 +2103,21 @@ static bool llama_kv_cache_init(
                              bool   offload) {
     const struct llama_hparams & hparams = model.hparams;
 
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
     const int64_t  n_layer      = hparams.n_layer;
 
     cache.has_shift = false;
 
-    // for now, only Mamba can hold state for more than one past token per cell
-    cache.unlimited = model.arch == LLM_ARCH_MAMBA;
+    // TODO: find a nicer way to add other recurrent model architectures
+    cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+
+    // TODO: support mixed recurrent Transformer architectures
+    // NOTE: (!a || b) is a logical implication (a -> b)
+    GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s());
+    GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s());
+    GGML_ASSERT( cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_gqa());
+    GGML_ASSERT( cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_gqa());
 
     cache.head = 0;
     cache.size = kv_size;
@@ -2086,7 +2129,8 @@ static bool llama_kv_cache_init(
     cache.cells.clear();
     cache.cells.resize(kv_size);
 
-    if (cache.unlimited) {
+    if (cache.recurrent) {
+        // init state copy sources
        for (uint32_t i = 0; i < cache.size; ++i) {
            cache.cells[i].src = i;
        }
@@ -2164,8 +2208,8 @@ static bool llama_kv_cache_find_slot(
    const uint32_t n_ctx    = cache.size;
    const uint32_t n_tokens = batch.n_tokens;
 
-    if (cache.unlimited) {
-        // For unlimited context architectures (like Mamba),
+    if (cache.recurrent) {
+        // For recurrent state architectures (like Mamba),
         // each KV cache cell can store the state for a whole sequence.
 
        // starting point to find the minimum seq_id used in the batch
@@ -2289,7 +2333,7 @@ static bool llama_kv_cache_seq_rm(
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
    // models like Mamba can't have a state partially erased
-    if (cache.unlimited) {
+    if (cache.recurrent) {
        if (seq_id >= (int64_t) cache.size) {
            // could be fatal
            return false;
@@ -2341,7 +2385,7 @@ static void llama_kv_cache_seq_cp(
    if (p0 < 0) p0 = 0;
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
-    if (cache.unlimited) {
+    if (cache.recurrent) {
        if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
            seq_id_src = cache.cells[seq_id_src].src;
            GGML_ASSERT((uint32_t) seq_id_src < cache.size);
@@ -2403,7 +2447,7 @@ static void llama_kv_cache_seq_add(
    if (p0 < 0) p0 = 0;
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
-    if (cache.unlimited) {
+    if (cache.recurrent) {
        // for Mamba-like models, only the pos needs to be shifted
        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
            llama_kv_cell & cell = cache.cells[seq_id];
@@ -2447,7 +2491,7 @@ static void llama_kv_cache_seq_div(
    if (p0 < 0) p0 = 0;
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
-    if (cache.unlimited) {
+    if (cache.recurrent) {
        // for Mamba-like models, only the pos needs to be changed
        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
            llama_kv_cell & cell = cache.cells[seq_id];
@@ -3277,7 +3321,7 @@ static void llm_load_hparams(
 
    // sanity check for n_rot (optional)
    {
-        hparams.n_rot = hparams.n_embd / hparams.n_head;
+        hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
 
        ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
 
@@ -3290,10 +3334,10 @@ static void llm_load_hparams(
        // gpt-j n_rot = rotary_dim
    }
 
-    hparams.n_embd_head_k = hparams.n_embd / hparams.n_head;
+    hparams.n_embd_head_k = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
    ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
 
-    hparams.n_embd_head_v = hparams.n_embd / hparams.n_head;
+    hparams.n_embd_head_v = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
    ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
 
    // arch-specific KVs
@@ -3545,7 +3589,13 @@ static void llm_load_hparams(
            } break;
        case LLM_ARCH_MAMBA:
            {
+                ml.get_key(LLM_KV_SSM_D_CONV,  hparams.ssm_d_conv);
+                ml.get_key(LLM_KV_SSM_D_INNER, hparams.ssm_d_inner);
+                ml.get_key(LLM_KV_SSM_D_STATE, hparams.ssm_d_state);
+                ml.get_key(LLM_KV_SSM_DT_RANK, hparams.ssm_dt_rank);
+
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
                switch (hparams.n_layer) {
                    case 24:
                        switch (hparams.n_embd) {
@@ -3886,6 +3936,10 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
    LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
    LLAMA_LOG_INFO("%s: n_yarn_orig_ctx  = %u\n",     __func__, hparams.n_yarn_orig_ctx);
    LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
+    LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
+    LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
+    LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
+    LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
    LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model.type));
    LLAMA_LOG_INFO("%s: model ftype      = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
    if (ml.n_elements >= 1e12) {
@@ -4050,10 +4104,7 @@ static bool llm_load_tensors(
    const int64_t n_vocab_type = hparams.n_vocab_type;
    const int64_t n_ff         = hparams.n_ff;
 
-    // Mamba uses these in its own way
-    if (model.arch != LLM_ARCH_MAMBA) {
-        GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
-    }
+    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
 
    ggml_context * ctx_input  = ctx_map.at(model.buft_input.buft);
    ggml_context * ctx_output = ctx_map.at(model.buft_output.buft);
@@ -4792,12 +4843,11 @@ static bool llm_load_tensors(
                } break;
            case LLM_ARCH_MAMBA:
                {
-                    const int64_t d_conv  = hparams.n_embd_head_k + 1;
-                    const int64_t d_state = hparams.n_embd_head_v;
-                    const int64_t d_inner = hparams.n_head;
-                    // TODO: allow loading dt_rank from the model config
-                    // ceiling division
-                    const int64_t dt_rank = (n_embd / 16) + (n_embd % 16 > 0);
+                    const int64_t d_conv  = hparams.ssm_d_conv;
+                    const int64_t d_inner = hparams.ssm_d_inner;
+                    const int64_t d_state = hparams.ssm_d_state;
+                    const int64_t dt_rank = hparams.ssm_dt_rank;
+                    // only an expansion factor of 2 is supported for now
                    GGML_ASSERT(2 * n_embd == d_inner);
 
                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -5420,7 +5470,7 @@ struct llm_build_context {
        norm_rms_eps     (hparams.f_norm_rms_eps),
        n_tokens         (batch.n_tokens),
        n_kv             (worst_case ? kv_self.size : kv_self.n),
-        kv_head          (worst_case ? (kv_self.unlimited ? 0 : kv_self.size - n_tokens) : kv_self.head),
+        kv_head          (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
        n_orig_ctx       (cparams.n_yarn_orig_ctx),
        pooling_type     (cparams.pooling_type),
        rope_type        (hparams.rope_type),
@@ -5473,8 +5523,8 @@ struct llm_build_context {
        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
        for (int il = 0; il < n_layer; ++il) {
-            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
-            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
+            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
 
            conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
            ssm_states  = ggml_get_rows(ctx0, ssm_states,  lctx.inp_s_copy);
@@ -8048,12 +8098,11 @@ struct llm_build_context {
        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
        const int64_t d_model = n_embd;
-        const int64_t d_inner = n_head;
+        const int64_t d_conv  = hparams.ssm_d_conv;
+        const int64_t d_inner = hparams.ssm_d_inner;
        GGML_ASSERT(2 * d_model == d_inner);
-        const int64_t d_conv = n_embd_head_k + 1;
-        const int64_t d_state = n_embd_head_v;
-        // ceiling division
-        const int64_t dt_rank = (d_model / 16) + (d_model % 16 > 0);
+        const int64_t d_state = hparams.ssm_d_state;
+        const int64_t dt_rank = hparams.ssm_dt_rank;
 
        struct ggml_tensor * cur;
        struct ggml_tensor * inpL;
@@ -8063,10 +8112,9 @@ struct llm_build_context {
        cb(inpL, "inp_embd", -1);
 
        for (int il = 0; il < n_layer; ++il) {
-            // (ab)using the kv cache to store the state
-            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], (d_conv-1)*(d_inner), kv_self.size);
-            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], (d_state)*(d_inner), kv_self.size);
+            // (ab)using the KV cache to store the states
+            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
 
            // clear states of sequences which are starting at the beginning of this batch
            {
@@ -8501,7 +8549,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
        }
    }
 
-    if (kv_self.unlimited) {
+    if (kv_self.recurrent) {
        const int64_t n_kv = kv_self.n;
 
        {
@@ -8667,7 +8715,7 @@ static int llama_decode_internal(
            return 1;
        }
 
-        if (!kv_self.unlimited) {
+        if (!kv_self.recurrent) {
            // a heuristic, to avoid attending the full cache if it is not yet utilized
            // after enough generations, the benefit from this heuristic disappears
            // if we start defragmenting the cache, the benefit from this will be more important
@@ -9056,7 +9104,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
        }
    }
 
-    if (lctx.kv_self.unlimited && lctx.kv_self.do_copy) {
+    if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
        llama_set_s_copy(lctx);
 
        {
@@ -12725,7 +12773,7 @@ struct llama_context * llama_new_context_with_model(
        // graph inputs
        {
            ggml_init_params init_params = {
-                /* .mem_size   */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.unlimited)),
+                /* .mem_size   */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.recurrent)),
                /* .mem_buffer */ nullptr,
                /* .no_alloc   */ true,
            };
@@ -12739,7 +12787,7 @@ struct llama_context * llama_new_context_with_model(
            ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
            ctx->inp_mean    = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
            ctx->inp_cls     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
-            if (ctx->kv_self.unlimited) {
+            if (ctx->kv_self.recurrent) {
                ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
                ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
                ctx->inp_s_seq  = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);
@@ -12753,7 +12801,7 @@ struct llama_context * llama_new_context_with_model(
            ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
            ggml_set_name(ctx->inp_mean,    "inp_mean");
            ggml_set_name(ctx->inp_cls,     "inp_cls");
-            if (ctx->kv_self.unlimited) {
+            if (ctx->kv_self.recurrent) {
                ggml_set_name(ctx->inp_s_copy, "inp_s_copy");
                ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
                ggml_set_name(ctx->inp_s_seq,  "inp_s_seq");
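
For context only (not part of the patch): a small, hypothetical sketch of how
the new gguf-py methods added in gguf_writer.py above could be exercised on
their own, assuming a local gguf-py checkout that already contains this change.
The file name and hyperparameter values are invented for illustration.

    # hypothetical usage of the new add_ssm_* writers; values are examples only
    from gguf import GGUFWriter

    d_model = 768
    writer = GGUFWriter("mamba-metadata-demo.gguf", arch="mamba")
    writer.add_block_count(24)
    writer.add_ssm_conv_kernel_size(4)          # {arch}.ssm.d_conv
    writer.add_ssm_inner_length(2 * d_model)    # {arch}.ssm.d_inner
    writer.add_ssm_state_length(16)             # {arch}.ssm.d_state
    writer.add_ssm_dt_rank(-(d_model // -16))   # {arch}.ssm.dt_rank (ceiling division)

    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()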