From 92653d05fdf8039439810380f323f32d7425ba96 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Mon, 9 Dec 2024 13:42:54 -0700
Subject: [PATCH] WIP: Partial work towards separate hybrid cache

This also seems like not _quite_ the right direction

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart

---
 include/llama.h |  2 ++
 src/llama.cpp   | 83 ++++++++++++++++++++++++++++++++++++++++---------
 2 files changed, 70 insertions(+), 15 deletions(-)

diff --git a/include/llama.h b/include/llama.h
index 90791d5f5..3f9e72f3d 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -489,6 +489,8 @@ extern "C" {
 
     // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
     LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
+    // Returns true if the model is hybrid (like Bamba, etc.)
+    LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
 
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(

diff --git a/src/llama.cpp b/src/llama.cpp
index 80f767282..2d0fe1e4d 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3348,6 +3348,10 @@ struct llama_context {
     struct llama_kv_cache kv_self;
     struct llama_control_vector cvec;
 
+    // Hybrid attention/ssm models use the kv cache differently for the
+    // attention and ssm layers, with different kv_size values
+    struct llama_kv_cache kv_hybrid;
+
     std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
 
     std::vector<ggml_backend_t> backends;
@@ -3511,7 +3515,8 @@ static bool llama_kv_cache_init(
                         ggml_type   type_k,
                         ggml_type   type_v,
                          uint32_t   kv_size,
-                             bool   offload) {
+                             bool   offload,
+                             bool   recurrent) {
     const llama_model & model = ctx->model;
     const llama_cparams & cparams = ctx->cparams;
@@ -3521,7 +3526,7 @@ static bool llama_kv_cache_init(
 
     cache.has_shift = false;
 
-    cache.recurrent = llama_model_is_recurrent(&model);
+    cache.recurrent = recurrent;
     cache.v_trans   = !cache.recurrent && !cparams.flash_attn;
 
     cache.head = 0;
@@ -9749,7 +9754,7 @@ static void llm_build_kv_store(
     } else {
         // note: the V cache is transposed when not using flash attention
         v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
-                (  n_ctx)*ggml_element_size(kv.v_l[il]),
+                (kv.size)*ggml_element_size(kv.v_l[il]),
                 (kv_head)*ggml_element_size(kv.v_l[il]));
 
         v_cur = ggml_transpose(ctx, v_cur);
@@ -10421,10 +10426,11 @@ static struct ggml_tensor * llm_build_mamba2(
                    int32_t   kv_head,
                    int32_t   n_kv,
         const llm_build_cb & cb,
-                       int   il) {
+                       int   il,
+                      bool   hybrid = false) {
     const llama_model    & model   = lctx.model;
     const llama_hparams  & hparams = model.hparams;
-    const llama_kv_cache & kv      = lctx.kv_self;
+    const llama_kv_cache & kv      = hybrid ? lctx.kv_hybrid : lctx.kv_self;
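+    // NOTE: for hybrid models the attention layers keep using kv_self, while
+    //       the ssm layers read and write their states through kv_hybrid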
     const int64_t d_conv  = hparams.ssm_d_conv;
     const int64_t d_inner = hparams.ssm_d_inner;
     const int64_t d_state = hparams.ssm_d_state;
@@ -10712,6 +10718,7 @@ struct llm_build_context {
     const llama_cparams  & cparams;
     const llama_ubatch   & ubatch;
     const llama_kv_cache & kv_self;
+    const llama_kv_cache & kv_hybrid;
 
     const int64_t n_embd;
     const int64_t n_layer;
@@ -10736,11 +10743,14 @@ struct llm_build_context {
     const float norm_rms_eps;
 
     const int32_t n_tokens;
-    const int32_t n_kv;     // size of KV cache to consider (n_kv <= kv_self.size)
+    const int32_t n_kv;            // size of KV cache to consider (n_kv <= kv_self.size)
+    const int32_t n_kv_hybrid;     // size of hybrid KV cache to consider (n_kv_hybrid <= kv_hybrid.size)
     const int32_t n_outputs;
     const int32_t n_outputs_enc;
-    const int32_t kv_head;  // index of where we store new KV data in the cache
-    const int32_t rs_zero;  // the first zero-ed recurrent state
+    const int32_t kv_head;         // index of where we store new KV data in the cache
+    const int32_t kv_head_hybrid;  // index of where we store new KV data in the hybrid cache
+    const int32_t rs_zero;         // the first zero-ed recurrent state
+    const int32_t rs_zero_hybrid;  // the first zero-ed recurrent state in the hybrid cache
     const int32_t n_ctx_orig;
 
     const bool flash_attn;
@@ -10766,6 +10776,7 @@ struct llm_build_context {
         cparams          (lctx.cparams),
         ubatch           (ubatch),
         kv_self          (lctx.kv_self),
+        kv_hybrid        (lctx.kv_hybrid),
         n_embd           (hparams.n_embd),
         n_layer          (hparams.n_layer),
@@ -10788,10 +10799,13 @@ struct llm_build_context {
         norm_rms_eps     (hparams.f_norm_rms_eps),
         n_tokens         (ubatch.n_tokens),
         n_kv             (worst_case ? kv_self.size : kv_self.n),
+        n_kv_hybrid      (worst_case ? kv_hybrid.size : kv_hybrid.n),
         n_outputs        (worst_case ? n_tokens : lctx.n_outputs),
         n_outputs_enc    (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
         kv_head          (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
+        kv_head_hybrid   (worst_case ? (kv_hybrid.recurrent ? 0 : kv_hybrid.size - n_tokens) : kv_hybrid.head),
         rs_zero          (kv_self.rs_z),
+        rs_zero_hybrid   (kv_hybrid.rs_z),
         n_ctx_orig       (cparams.n_ctx_orig_yarn),
         flash_attn       (cparams.flash_attn),
         pooling_type     (cparams.pooling_type),
@@ -14687,7 +14701,8 @@ struct llm_build_context {
             if (hparams.recurrent_layer(il)) {
                 // ssm layer
                 cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy,
-                        rs_zero, kv_head, n_kv, cb, il);
+                        rs_zero_hybrid, kv_head_hybrid, n_kv_hybrid, cb, il, true);
+                cb(cur, "mamba_out", il);
             } else {
                 // attention layer
                 //
@@ -20325,14 +20340,23 @@ struct llama_context * llama_new_context_with_model(
     ctx->is_encoding = llama_model_has_encoder(model);
 
     uint32_t kv_size = cparams.n_ctx;
+    uint32_t kv_size_hybrid = 0;
     ggml_type type_k = params.type_k;
     ggml_type type_v = params.type_v;
 
+    const bool recurrent = llama_model_is_recurrent(model);
+    const bool hybrid    = llama_model_is_hybrid(model);
+
     // Mamba only needs a constant number of KV cache cells per sequence
-    if (llama_model_is_recurrent(model)) {
+    if (recurrent) {
         // Mamba needs at least as many KV cells as there are sequences kept at any time
-        kv_size = std::max((uint32_t) 1, params.n_seq_max);
+        // NOTE: Hybrid models will use the hybrid cache for the SSM layers
+        if (hybrid) {
+            kv_size_hybrid = std::max((uint32_t) 1, params.n_seq_max);
+        } else {
+            kv_size = std::max((uint32_t) 1, params.n_seq_max);
+        }
         // it's probably best to keep as much precision as possible for the states
+        // TODO: should types be different for the two caches?
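+        //       (as written, F32 is forced onto the attention cache of a hybrid
+        //       model as well, since both init calls below share type_k/type_v)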
         type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
         type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
     }
@@ -20389,24 +20413,44 @@ struct llama_context * llama_new_context_with_model(
 
         llama_set_abort_callback(ctx, params.abort_callback, params.abort_callback_data);
 
-        if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
+        // the self cache is recurrent IFF the model is recurrent, but not hybrid
+        if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv, recurrent && !hybrid)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
             return nullptr;
         }
 
         {
+            // Log cache memory usage
             size_t memory_size_k = 0;
             size_t memory_size_v = 0;
 
-            for (auto & k : ctx->kv_self.k_l) { memory_size_k += ggml_nbytes(k); }
-            for (auto & v : ctx->kv_self.v_l) { memory_size_v += ggml_nbytes(v); }
+            for (auto & k : ctx->kv_self.k_l) {
+                memory_size_k += ggml_nbytes(k);
+            }
+            for (auto & v : ctx->kv_self.v_l) {
+                memory_size_v += ggml_nbytes(v);
+            }
 
             LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
                       (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
                       ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                       ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
         }
 
+        // For hybrid models, initialize the hybrid kv cache
+        if (kv_size_hybrid > 0 && !llama_kv_cache_init(ctx->kv_hybrid, ctx, type_k, type_v, kv_size_hybrid, cparams.offload_kqv, true)) {
+            LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for hybrid cache\n", __func__);
+            llama_free(ctx);
+            return nullptr;
+        }
+
+        {
+            // Log hybrid cache memory usage
+            size_t memory_size_k = 0;
+            size_t memory_size_v = 0;
+
+            for (auto & k : ctx->kv_hybrid.k_l) {
+                memory_size_k += ggml_nbytes(k);
+            }
+            for (auto & v : ctx->kv_hybrid.v_l) {
+                memory_size_v += ggml_nbytes(v);
+            }
+
+            LLAMA_LOG_INFO("%s: KV hybrid size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                      (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                      ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                      ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+        }
@@ -20763,6 +20807,15 @@ bool llama_model_is_recurrent(const struct llama_model * model) {
     }
 }
 
+bool llama_model_is_hybrid(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_BAMBA:
+            return true;
+        default:
+            return false;
+    }
+}
+
 uint32_t llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
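
A minimal caller-side sketch of how the new predicate would be exercised (untested; the model path and the n_seq_max value are placeholders, everything else is the existing llama.h API):

    #include "llama.h"

    int main(void) {
        llama_model * model = llama_load_model_from_file(
            "/path/to/model.gguf", llama_model_default_params());
        if (model == NULL) {
            return 1;
        }

        llama_context_params cparams = llama_context_default_params();
        if (llama_model_is_hybrid(model) || llama_model_is_recurrent(model)) {
            // recurrent states are sized by n_seq_max rather than n_ctx,
            // so this is the knob that controls the SSM cache footprint
            cparams.n_seq_max = 2;
        }

        llama_context * ctx = llama_new_context_with_model(model, cparams);
        if (ctx == NULL) {
            llama_free_model(model);
            return 1;
        }

        // ... evaluate tokens ...

        llama_free(ctx);
        llama_free_model(model);
        return 0;
    }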