From d3a34e0282579ba08d773bb7760f4a6a2060cb8a Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Mon, 9 Dec 2024 15:51:32 -0700
Subject: [PATCH] fix: per-layer recurrent embd_[kv]_s

For hybrid models, this value should be 0 for the non-recurrent layers

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart
---
 src/llama.cpp | 49 ++++++++++++++++++++++++++++---------------------
 1 file changed, 28 insertions(+), 21 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 0d97e54c3..80f767282 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2485,7 +2485,7 @@ struct llama_hparams {
     uint32_t ssm_head_dim = 0;
 
     // for hybrid state space models
-    std::array<bool, LLAMA_MAX_LAYERS> ssm_layer_arr;
+    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
 
     float f_clamp_kqv = 0.0f;
     float f_max_alibi_bias = 0.0f;
@@ -2544,7 +2544,7 @@ struct llama_hparams {
         if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;
         if (this->ssm_head_dim != other.ssm_head_dim) return true;
 
-        if (this->ssm_layer_arr != other.ssm_layer_arr) return true;
+        if (this->recurrent_layer_arr != other.recurrent_layer_arr) return true;
 
         if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true;
         if (this->time_mix_extra_dim != other.time_mix_extra_dim) return true;
@@ -2616,30 +2616,34 @@ struct llama_hparams {
         return n_embd_head_v * n_head_kv;
     }
 
-    uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
+    uint32_t n_embd_k_s(uint32_t il = 0) const { // dimension of the rolling state embeddings
         // corresponds to Mamba's conv_states size or RWKV's token_shift states size
+        if (!recurrent_layer(il)) {
+            return 0;
+        }
         if (wkv_head_size != 0) {
             // for RWKV models
             return 2 * n_embd;
-        } else {
-            // TODO: maybe support other convolution strides than 1
-            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-            return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
         }
+        // TODO: maybe support other convolution strides than 1
+        // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
     }
 
-    uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
+    uint32_t n_embd_v_s(uint32_t il = 0) const { // dimension of the recurrent state embeddings
+        if (!recurrent_layer(il)) {
+            return 0;
+        }
         if (wkv_head_size != 0) {
             // corresponds to RWKV's wkv_states size
             return n_embd * wkv_head_size;
-        } else {
-            // corresponds to Mamba's ssm_states size
-            return ssm_d_state * ssm_d_inner;
         }
+        // corresponds to Mamba's ssm_states size
+        return ssm_d_state * ssm_d_inner;
     }
 
-    bool ssm_layer(uint32_t il) const {
-        return ssm_layer_arr[il];
+    bool recurrent_layer(uint32_t il) const {
+        return recurrent_layer_arr[il];
     }
 };
 
@@ -3555,8 +3559,8 @@ static bool llama_kv_cache_init(
     cache.v_l.reserve(n_layer);
 
     for (int i = 0; i < (int) n_layer; i++) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);
 
         ggml_backend_buffer_type_t buft;
         if (offload) {
@@ -5509,7 +5513,10 @@ static void llm_load_hparams(
     std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
     std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
     std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
-    std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), false);
+    std::fill(
+        hparams.recurrent_layer_arr.begin(),
+        hparams.recurrent_layer_arr.end(),
+        llama_model_is_recurrent(&model));
 
     ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer);
     ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
@@ -5999,12 +6006,12 @@ static void llm_load_hparams(
                 std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
 
                 // Attention params
-                std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), true);
+                std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true);
                 std::vector<uint32_t> attn_layer_indices;
                 ml.get_arr(LLM_KV_ATTENTION_LAYER_INDICES, attn_layer_indices);
                 for (const auto attn_idx : attn_layer_indices) {
                     GGML_ASSERT(attn_idx < hparams.n_layer);
-                    hparams.ssm_layer_arr[attn_idx] = false;
+                    hparams.recurrent_layer_arr[attn_idx] = false;
                     // Correctly set n_head and n_head_kv for attention layers
                     hparams.n_head_arr[attn_idx] = n_head_attn;
                     hparams.n_head_kv_arr[attn_idx] = n_head_kv_attn;
@@ -7162,7 +7169,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     }
 
     if (model.arch == LLM_ARCH_BAMBA) {
-        LLAMA_LOG_INFO("%s: ssm_layer_arr = %s\n", __func__, print_f([&](uint32_t il) { return uint32_t(hparams.ssm_layer(il)); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: recurrent_layer_arr = %s\n", __func__, print_f([&](uint32_t il) { return uint32_t(hparams.recurrent_layer(il)); }, hparams.n_layer).c_str());
     }
 }
 
@@ -8769,7 +8776,7 @@ static bool llm_load_tensors(
                        // norm
                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                        if (hparams.ssm_layer(i)) {
+                        if (hparams.recurrent_layer(i)) {
                            // ssm layers
                            layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
                            layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, d_in_proj}, llama_model_loader::TENSOR_NOT_REQUIRED);
@@ -14677,7 +14684,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-            if (hparams.ssm_layer(il)) {
+            if (hparams.recurrent_layer(il)) {
                 // ssm layer
                 cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy, rs_zero, kv_head, n_kv, cb, il);
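
A minimal standalone sketch of the behavior this patch introduces, with made-up layer layout and SSM dimensions (N_LAYER, SSM_D_CONV, SSM_D_INNER, SSM_N_GROUP, and SSM_D_STATE are illustrative values, not read from a real model): a per-layer recurrent flag gates the recurrent-state sizes, so attention layers in a hybrid model report 0 and the per-layer sizing in llama_kv_cache_init reserves conv/ssm state only for layers that actually carry one.

#include <array>
#include <cstdint>
#include <iostream>

// Illustrative constants; a real model reads these from its metadata.
constexpr uint32_t N_LAYER     = 4;
constexpr uint32_t SSM_D_CONV  = 4;
constexpr uint32_t SSM_D_INNER = 4096;
constexpr uint32_t SSM_N_GROUP = 1;
constexpr uint32_t SSM_D_STATE = 128;

struct toy_hparams {
    // true = recurrent (SSM) layer, false = attention layer
    std::array<bool, N_LAYER> recurrent_layer_arr;

    bool recurrent_layer(uint32_t il) const { return recurrent_layer_arr[il]; }

    // rolling (conv) state size for a layer: 0 for non-recurrent layers
    uint32_t n_embd_k_s(uint32_t il) const {
        if (!recurrent_layer(il)) {
            return 0;
        }
        return (SSM_D_CONV - 1) * (SSM_D_INNER + 2 * SSM_N_GROUP * SSM_D_STATE);
    }

    // recurrent (ssm) state size for a layer: 0 for non-recurrent layers
    uint32_t n_embd_v_s(uint32_t il) const {
        if (!recurrent_layer(il)) {
            return 0;
        }
        return SSM_D_STATE * SSM_D_INNER;
    }
};

int main() {
    // Hypothetical hybrid layout: layers 0-2 recurrent, layer 3 attention.
    toy_hparams hp{{true, true, true, false}};

    for (uint32_t il = 0; il < N_LAYER; ++il) {
        std::cout << "layer " << il
                  << ": k_s=" << hp.n_embd_k_s(il)
                  << " v_s=" << hp.n_embd_v_s(il)
                  << (hp.recurrent_layer(il) ? " (recurrent)\n" : " (attention)\n");
    }
    return 0;
}

The default argument il = 0 on n_embd_k_s/n_embd_v_s in the patch keeps existing call sites that pass no layer index compiling unchanged; only the hybrid code paths need to pass an explicit layer.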