WIP: Partial work towards separate hybrid cache

This also doesn't seem like _quite_ the right direction

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Gabe Goodhart 2024-12-09 13:42:54 -07:00
parent d3a34e0282
commit 92653d05fd
2 changed files with 70 additions and 15 deletions

include/llama.h

@@ -489,6 +489,8 @@ extern "C" {
     // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
     LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
 
+    // Returns true if the model is hybrid (like Bamba, etc.)
+    LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
 
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
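
For orientation, a minimal sketch of how a caller might branch on the new predicate together with the existing one; the dispatch below is illustrative and not part of this commit:

// Illustrative only: both predicates come from llama.h; llama_model_is_hybrid()
// is the one added by this commit.
#include "llama.h"
#include <cstdio>

static void print_cache_strategy(const struct llama_model * model) {
    if (llama_model_is_hybrid(model)) {
        // attention layers use the main kv cache, ssm layers a per-sequence cache
        printf("hybrid: context-sized attention cache + per-sequence ssm cache\n");
    } else if (llama_model_is_recurrent(model)) {
        printf("recurrent: per-sequence state cells only\n");
    } else {
        printf("attention-only: kv cache sized by n_ctx\n");
    }
}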

src/llama.cpp

@@ -3348,6 +3348,10 @@ struct llama_context {
     struct llama_kv_cache kv_self;
     struct llama_control_vector cvec;
 
+    // Hybrid attention/ssm models use the kv cache differently for the
+    // attention and ssm layers, which need different kv_size values
+    struct llama_kv_cache kv_hybrid;
+
     std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
 
     std::vector<ggml_backend_ptr> backends;
@@ -3511,7 +3515,8 @@ static bool llama_kv_cache_init(
             ggml_type   type_k,
             ggml_type   type_v,
              uint32_t   kv_size,
-                 bool   offload) {
+                 bool   offload,
+                 bool   recurrent) {
     const llama_model & model = ctx->model;
     const llama_cparams & cparams = ctx->cparams;
@@ -3521,7 +3526,7 @@ static bool llama_kv_cache_init(
     cache.has_shift = false;
 
-    cache.recurrent = llama_model_is_recurrent(&model);
+    cache.recurrent = recurrent;
     cache.v_trans   = !cache.recurrent && !cparams.flash_attn;
 
     cache.head = 0;
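
Threading `recurrent` through as a parameter decouples "this cache is recurrent" from "this model is recurrent", which matters once one context owns two caches. A minimal sketch of the pairing this enables, using the names from the context-creation hunk further down (illustrative, not in the tree):

// Illustrative restatement: the same hybrid model produces one non-recurrent
// cache (attention layers) and one recurrent cache (ssm layers), so the flag
// has to be per-cache rather than derived from the model inside the init.
static bool init_both_caches(llama_context * ctx, ggml_type type_k, ggml_type type_v,
                             uint32_t kv_size, uint32_t kv_size_hybrid,
                             bool offload, bool recurrent, bool hybrid) {
    if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, offload, recurrent && !hybrid)) {
        return false;
    }
    // the hybrid cache, when present, is always the recurrent one
    return kv_size_hybrid == 0 ||
           llama_kv_cache_init(ctx->kv_hybrid, ctx, type_k, type_v, kv_size_hybrid, offload, true);
}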
@@ -9749,7 +9754,7 @@ static void llm_build_kv_store(
     } else {
         // note: the V cache is transposed when not using flash attention
         v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
-                (  n_ctx)*ggml_element_size(kv.v_l[il]),
+                (kv.size)*ggml_element_size(kv.v_l[il]),
                 (kv_head)*ggml_element_size(kv.v_l[il]));
 
         v_cur = ggml_transpose(ctx, v_cur);
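
This one-token change is load-bearing for the split cache: the transposed V cache stores n_embd_v_gqa rows of kv_size elements each, so the row stride must come from the cache itself. A small layout sketch (helper names invented for illustration):

// Byte layout behind the ggml_view_2d call above. Consecutive embedding rows
// are kv_size elements apart; using n_ctx was only correct while every cache
// had kv_size == n_ctx, which stops being true once the hybrid cache is sized
// per sequence.
#include <cstddef>

struct v_view {
    size_t nb1;    // byte stride between embedding rows
    size_t offset; // byte offset of the first column written at kv_head
};

static v_view v_cache_window(size_t kv_size, size_t kv_head, size_t elem_size) {
    return { kv_size * elem_size, kv_head * elem_size };
}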
@@ -10421,10 +10426,11 @@ static struct ggml_tensor * llm_build_mamba2(
          int32_t   kv_head,
          int32_t   n_kv,
     const llm_build_cb & cb,
-         int       il) {
+         int       il,
+         bool      hybrid = false) {
    const llama_model    & model   = lctx.model;
    const llama_hparams  & hparams = model.hparams;
-   const llama_kv_cache & kv      = lctx.kv_self;
+   const llama_kv_cache & kv      = hybrid ? lctx.kv_hybrid : lctx.kv_self;
 
    const int64_t d_conv  = hparams.ssm_d_conv;
    const int64_t d_inner = hparams.ssm_d_inner;
    const int64_t d_state = hparams.ssm_d_state;
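
Defaulting `hybrid` to `false` keeps every existing `llm_build_mamba2` call site compiling unchanged; only the hybrid graph builder opts in. The selection itself reduces to a one-liner (sketch, using the internal types as named in this diff):

// Pure-recurrent models keep reading ssm state from kv_self; hybrid models
// route their ssm layers to the per-sequence kv_hybrid cache.
static const llama_kv_cache & select_kv(const llama_context & lctx, bool hybrid) {
    return hybrid ? lctx.kv_hybrid : lctx.kv_self;
}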
@@ -10712,6 +10718,7 @@ struct llm_build_context {
     const llama_cparams  & cparams;
     const llama_ubatch   & ubatch;
     const llama_kv_cache & kv_self;
+    const llama_kv_cache & kv_hybrid;
 
     const int64_t n_embd;
     const int64_t n_layer;
@@ -10736,11 +10743,14 @@ struct llm_build_context {
     const float norm_rms_eps;
 
     const int32_t n_tokens;
     const int32_t n_kv;             // size of KV cache to consider (n_kv <= kv_self.size)
+    const int32_t n_kv_hybrid;      // size of KV cache to consider (n_kv_hybrid <= kv_hybrid.size)
     const int32_t n_outputs;
     const int32_t n_outputs_enc;
     const int32_t kv_head;          // index of where we store new KV data in the cache
+    const int32_t kv_head_hybrid;   // index of where we store new KV data in the hybrid cache
     const int32_t rs_zero;          // the first zero-ed recurrent state
+    const int32_t rs_zero_hybrid;   // the first zero-ed recurrent state in the hybrid cache
     const int32_t n_ctx_orig;
@@ -10766,6 +10776,7 @@ struct llm_build_context {
         cparams       (lctx.cparams),
         ubatch        (ubatch),
         kv_self       (lctx.kv_self),
+        kv_hybrid     (lctx.kv_hybrid),
         n_embd        (hparams.n_embd),
         n_layer       (hparams.n_layer),
         n_rot         (hparams.n_rot),
@@ -10788,10 +10799,13 @@ struct llm_build_context {
         norm_rms_eps    (hparams.f_norm_rms_eps),
         n_tokens        (ubatch.n_tokens),
         n_kv            (worst_case ? kv_self.size   : kv_self.n),
+        n_kv_hybrid     (worst_case ? kv_hybrid.size : kv_hybrid.n),
         n_outputs       (worst_case ? n_tokens : lctx.n_outputs),
         n_outputs_enc   (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
         kv_head         (worst_case ? (kv_self.recurrent   ? 0 : kv_self.size   - n_tokens) : kv_self.head),
+        kv_head_hybrid  (worst_case ? (kv_hybrid.recurrent ? 0 : kv_hybrid.size - n_tokens) : kv_hybrid.head),
         rs_zero         (kv_self.rs_z),
+        rs_zero_hybrid  (kv_hybrid.rs_z),
         n_ctx_orig      (cparams.n_ctx_orig_yarn),
         flash_attn      (cparams.flash_attn),
         pooling_type    (cparams.pooling_type),
@@ -14687,7 +14701,8 @@ struct llm_build_context {
             if (hparams.recurrent_layer(il)) {
                 // ssm layer //
                 cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy,
-                        rs_zero, kv_head, n_kv, cb, il);
+                        rs_zero_hybrid, kv_head_hybrid, n_kv_hybrid, cb, il, true);
+                cb(cur, "mamba_out", il);
             } else {
                 // attention layer //
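
The Bamba builder dispatches per layer on `hparams.recurrent_layer(il)`, which is defined elsewhere on this branch. A hypothetical shape for that predicate, assuming a per-layer boolean array (member name invented here):

// Hypothetical sketch, not shown in this diff: hybrid architectures mark
// which layers are ssm vs. attention, and the graph builder branches on the
// flag layer by layer.
#include <array>
#include <cstdint>

struct hparams_sketch {
    std::array<bool, 64> recurrent_layer_arr{}; // true => ssm layer

    bool recurrent_layer(uint32_t il) const {
        return recurrent_layer_arr[il];
    }
};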
@@ -20325,14 +20340,23 @@ struct llama_context * llama_new_context_with_model(
     ctx->is_encoding = llama_model_has_encoder(model);
 
     uint32_t kv_size = cparams.n_ctx;
+    uint32_t kv_size_hybrid = 0;
     ggml_type type_k = params.type_k;
     ggml_type type_v = params.type_v;
 
+    const bool recurrent = llama_model_is_recurrent(model);
+    const bool hybrid    = llama_model_is_hybrid(model);
+
     // Mamba only needs a constant number of KV cache cells per sequence
-    if (llama_model_is_recurrent(model)) {
+    if (recurrent) {
         // Mamba needs at least as many KV cells as there are sequences kept at any time
-        kv_size = std::max((uint32_t) 1, params.n_seq_max);
+        // NOTE: Hybrid models will use the hybrid cache for the SSM layers
+        if (hybrid) {
+            kv_size_hybrid = std::max((uint32_t) 1, params.n_seq_max);
+        } else {
+            kv_size = std::max((uint32_t) 1, params.n_seq_max);
+        }
         // it's probably best to keep as much precision as possible for the states
+        // TODO: should types be different for the two caches?
         type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
         type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
     }
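
A worked example of the sizing above, with illustrative numbers: serving a hybrid model at n_ctx = 4096 with up to 4 sequences keeps the full context-sized attention cache and adds only 4 SSM cells:

// Mirrors the branch logic above; numbers are illustrative only.
#include <algorithm>
#include <cstdint>

int main() {
    const uint32_t n_ctx     = 4096;
    const uint32_t n_seq_max = 4;
    const bool recurrent = true; // assumed: hybrid models report recurrent layers
    const bool hybrid    = true;

    uint32_t kv_size        = n_ctx; // attention layers: one cell per position
    uint32_t kv_size_hybrid = 0;

    if (recurrent) {
        if (hybrid) {
            kv_size_hybrid = std::max((uint32_t) 1, n_seq_max); // -> 4
        } else {
            kv_size = std::max((uint32_t) 1, n_seq_max);
        }
    }
    // result: kv_size == 4096, kv_size_hybrid == 4
    return 0;
}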
@@ -20389,24 +20413,44 @@ struct llama_context * llama_new_context_with_model(
         llama_set_abort_callback(ctx, params.abort_callback, params.abort_callback_data);
 
-        if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
+        // the self cache is recurrent IFF the model is recurrent and not hybrid
+        if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv, recurrent && !hybrid)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
             return nullptr;
         }
 
         {
+            // Log cache memory usage
             size_t memory_size_k = 0;
             size_t memory_size_v = 0;
 
             for (auto & k : ctx->kv_self.k_l) {
                 memory_size_k += ggml_nbytes(k);
             }
 
             for (auto & v : ctx->kv_self.v_l) {
                 memory_size_v += ggml_nbytes(v);
             }
 
+            LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+        }
+
+        // For hybrid models, initialize the hybrid kv cache
+        if (kv_size_hybrid > 0 && !llama_kv_cache_init(ctx->kv_hybrid, ctx, type_k, type_v, kv_size_hybrid, cparams.offload_kqv, true)) {
+            LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for hybrid cache\n", __func__);
+            llama_free(ctx);
+            return nullptr;
+        }
+
+        {
+            // Log hybrid cache memory usage
+            size_t memory_size_k = 0;
+            size_t memory_size_v = 0;
+
+            for (auto & k : ctx->kv_hybrid.k_l) {
+                memory_size_k += ggml_nbytes(k);
+            }
+
+            for (auto & v : ctx->kv_hybrid.v_l) {
+                memory_size_v += ggml_nbytes(v);
+            }
+
-            LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+            LLAMA_LOG_INFO("%s: KV hybrid size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
                 (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
                 ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
@@ -20763,6 +20807,15 @@ bool llama_model_is_recurrent(const struct llama_model * model) {
     }
 }
 
+bool llama_model_is_hybrid(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_BAMBA:
+            return true;
+        default:
+            return false;
+    }
+}
+
 uint32_t llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
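
Finally, a hypothetical smoke test for the new predicate. `llama_model_default_params`, `llama_load_model_from_file`, and `llama_free_model` are the existing public API; the model path and the recurrent assertion are assumptions:

#include "llama.h"
#include <cassert>

int main() {
    llama_model_params mparams = llama_model_default_params();
    // placeholder path to a Bamba checkpoint
    llama_model * model = llama_load_model_from_file("bamba.gguf", mparams);
    if (model != nullptr) {
        assert(llama_model_is_hybrid(model));
        // assumed on this branch: hybrid models also report recurrent layers,
        // matching the kv_size logic in llama_new_context_with_model
        assert(llama_model_is_recurrent(model));
        llama_free_model(model);
    }
    return 0;
}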