WIP: Partial work towards separate hybrid cache

This also seems like not _quite_ the right direction

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

parent d3a34e0282
commit 92653d05fd

2 changed files with 70 additions and 15 deletions
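For orientation, here is a minimal, self-contained sketch of the layout this change works toward (simplified stand-in types, not the committed code): hybrid attention/SSM models such as Bamba keep the existing kv_self cache for their attention layers and add a second, much smaller kv_hybrid cache for the recurrent (SSM) layers, with each layer routed to the cache it needs. The layer schedule below is hypothetical.

// Sketch only: simplified stand-ins for the real llama.cpp types.
#include <cstdint>
#include <cstdio>
#include <vector>

struct kv_cache_sketch {
    uint32_t size;      // number of cells
    bool     recurrent; // true when the cells hold Mamba conv/ssm states
};

struct hybrid_ctx_sketch {
    kv_cache_sketch kv_self;   // attention layers: one cell per context position
    kv_cache_sketch kv_hybrid; // SSM layers: one cell per tracked sequence
};

int main() {
    const uint32_t n_ctx = 4096, n_seq_max = 1;
    // hypothetical layer schedule: true marks a recurrent (SSM) layer
    const std::vector<bool> recurrent_layer = {true, true, false, true};

    const hybrid_ctx_sketch ctx = {
        {n_ctx,     false}, // kv_self
        {n_seq_max, true }, // kv_hybrid
    };

    for (size_t il = 0; il < recurrent_layer.size(); ++il) {
        const kv_cache_sketch & kv = recurrent_layer[il] ? ctx.kv_hybrid : ctx.kv_self;
        std::printf("layer %zu -> %s cache (%u cells)\n",
                    il, kv.recurrent ? "hybrid/SSM" : "attention", kv.size);
    }
    return 0;
}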
@@ -489,6 +489,8 @@ extern "C" {
 
     // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
     LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
+    // Returns true if the model is hybrid (like Bamba, etc.)
+    LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
 
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
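A hedged usage sketch for the new predicate (classify() is a hypothetical caller-side helper, not part of llama.h, and it requires this patch to be applied): callers that size caches can ask both questions, and the nested recurrent/hybrid checks later in this diff suggest a hybrid model like Bamba answers true to both.

// Hypothetical helper, not llama.cpp code: picks a cache layout from the two
// predicates declared above. Compiles against llama.h with this patch applied.
#include "llama.h"

enum class cache_layout {
    attention_only, // regular transformer: K/V per context position
    recurrent_only, // Mamba/RWKV: constant-size state per sequence
    hybrid,         // Bamba: attention cache + separate SSM-state cache
};

static cache_layout classify(const struct llama_model * model) {
    if (llama_model_is_hybrid(model)) {
        return cache_layout::hybrid;
    }
    if (llama_model_is_recurrent(model)) {
        return cache_layout::recurrent_only;
    }
    return cache_layout::attention_only;
}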
@@ -3348,6 +3348,10 @@ struct llama_context {
     struct llama_kv_cache kv_self;
     struct llama_control_vector cvec;
 
+    // Hybrid attention/ssm models use kv cache differently for attention/ssm
+    // layers with different kv_size values
+    struct llama_kv_cache kv_hybrid;
+
     std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
 
     std::vector<ggml_backend_ptr> backends;
@@ -3511,7 +3515,8 @@ static bool llama_kv_cache_init(
         ggml_type type_k,
         ggml_type type_v,
         uint32_t  kv_size,
-        bool      offload) {
+        bool      offload,
+        bool      recurrent) {
     const llama_model & model = ctx->model;
     const llama_cparams & cparams = ctx->cparams;
 
@@ -3521,7 +3526,7 @@ static bool llama_kv_cache_init(
 
     cache.has_shift = false;
 
-    cache.recurrent = llama_model_is_recurrent(&model);
+    cache.recurrent = recurrent;
     cache.v_trans = !cache.recurrent && !cparams.flash_attn;
 
     cache.head = 0;
@@ -9749,7 +9754,7 @@ static void llm_build_kv_store(
     } else {
         // note: the V cache is transposed when not using flash attention
         v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
-                ( n_ctx)*ggml_element_size(kv.v_l[il]),
+                (kv.size)*ggml_element_size(kv.v_l[il]),
                 (kv_head)*ggml_element_size(kv.v_l[il]));
 
         v_cur = ggml_transpose(ctx, v_cur);
@@ -10421,10 +10426,11 @@ static struct ggml_tensor * llm_build_mamba2(
         int32_t   kv_head,
         int32_t   n_kv,
         const llm_build_cb & cb,
-        int       il) {
+        int       il,
+        bool      hybrid = false) {
     const llama_model & model = lctx.model;
     const llama_hparams & hparams = model.hparams;
-    const llama_kv_cache & kv = lctx.kv_self;
+    const llama_kv_cache & kv = hybrid ? lctx.kv_hybrid : lctx.kv_self;
     const int64_t d_conv = hparams.ssm_d_conv;
     const int64_t d_inner = hparams.ssm_d_inner;
     const int64_t d_state = hparams.ssm_d_state;
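The defaulted `hybrid` flag keeps existing pure-Mamba call sites on kv_self, while the Bamba graph (see the llm_build_mamba2 call later in this diff) opts into kv_hybrid. A self-contained sketch of that routing pattern, with simplified stand-in types (build_ssm_layer stands in for llm_build_mamba2 and is not real llama.cpp code):

#include <cstdio>

struct cache_sketch { const char * name; };
struct ctx_sketch   { cache_sketch kv_self{"kv_self"}, kv_hybrid{"kv_hybrid"}; };

// Existing recurrent-only callers omit the flag and keep using kv_self;
// a hybrid (Bamba-style) graph passes true and is redirected to kv_hybrid.
static void build_ssm_layer(ctx_sketch & lctx, int il, bool hybrid = false) {
    const cache_sketch & kv = hybrid ? lctx.kv_hybrid : lctx.kv_self;
    std::printf("layer %d uses %s\n", il, kv.name);
}

int main() {
    ctx_sketch lctx;
    build_ssm_layer(lctx, 0);       // pure Mamba path -> kv_self
    build_ssm_layer(lctx, 1, true); // hybrid path     -> kv_hybrid
    return 0;
}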
@@ -10712,6 +10718,7 @@ struct llm_build_context {
     const llama_cparams & cparams;
     const llama_ubatch & ubatch;
     const llama_kv_cache & kv_self;
+    const llama_kv_cache & kv_hybrid;
 
     const int64_t n_embd;
     const int64_t n_layer;
@@ -10736,11 +10743,14 @@ struct llm_build_context {
     const float norm_rms_eps;
 
     const int32_t n_tokens;
-    const int32_t n_kv;     // size of KV cache to consider (n_kv <= kv_self.size)
+    const int32_t n_kv;            // size of KV cache to consider (n_kv <= kv_self.size)
+    const int32_t n_kv_hybrid;     // size of KV cache to consider (n_kv_hybrid <= kv_hybrid.size)
     const int32_t n_outputs;
     const int32_t n_outputs_enc;
-    const int32_t kv_head;  // index of where we store new KV data in the cache
-    const int32_t rs_zero;  // the first zero-ed recurrent state
+    const int32_t kv_head;         // index of where we store new KV data in the cache
+    const int32_t kv_head_hybrid;  // index of where we store new KV data in the hybrid cache
+    const int32_t rs_zero;         // the first zero-ed recurrent state
+    const int32_t rs_zero_hybrid;  // the first zero-ed recurrent state
     const int32_t n_ctx_orig;
 
     const bool flash_attn;
@@ -10766,6 +10776,7 @@ struct llm_build_context {
         cparams (lctx.cparams),
         ubatch (ubatch),
         kv_self (lctx.kv_self),
+        kv_hybrid (lctx.kv_hybrid),
         n_embd (hparams.n_embd),
         n_layer (hparams.n_layer),
         n_rot (hparams.n_rot),
@@ -10788,10 +10799,13 @@ struct llm_build_context {
         norm_rms_eps (hparams.f_norm_rms_eps),
         n_tokens (ubatch.n_tokens),
         n_kv (worst_case ? kv_self.size : kv_self.n),
+        n_kv_hybrid (worst_case ? kv_hybrid.size : kv_self.n),
         n_outputs (worst_case ? n_tokens : lctx.n_outputs),
         n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
         kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
+        kv_head_hybrid (worst_case ? (kv_hybrid.recurrent ? 0 : kv_hybrid.size - n_tokens) : kv_hybrid.head),
         rs_zero (kv_self.rs_z),
+        rs_zero_hybrid (kv_hybrid.rs_z),
         n_ctx_orig (cparams.n_ctx_orig_yarn),
         flash_attn (cparams.flash_attn),
         pooling_type (cparams.pooling_type),
@@ -14687,7 +14701,8 @@ struct llm_build_context {
                 if (hparams.recurrent_layer(il)) {
                     // ssm layer
                     cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy,
-                            rs_zero, kv_head, n_kv, cb, il);
+                            rs_zero_hybrid, kv_head_hybrid, n_kv_hybrid, cb, il, true);
                     cb(cur, "mamba_out", il);
                 } else {
                     // attention layer //
 
@@ -20325,14 +20340,23 @@ struct llama_context * llama_new_context_with_model(
     ctx->is_encoding = llama_model_has_encoder(model);
 
     uint32_t kv_size = cparams.n_ctx;
+    uint32_t kv_size_hybrid = 0;
     ggml_type type_k = params.type_k;
     ggml_type type_v = params.type_v;
+    const bool recurrent = llama_model_is_recurrent(model);
+    const bool hybrid = llama_model_is_hybrid(model);
 
     // Mamba only needs a constant number of KV cache cells per sequence
-    if (llama_model_is_recurrent(model)) {
+    if (recurrent) {
         // Mamba needs at least as many KV cells as there are sequences kept at any time
-        kv_size = std::max((uint32_t) 1, params.n_seq_max);
+        // NOTE: Hybrid models will use the hybrid cache for the SSM layers
+        if (hybrid) {
+            kv_size_hybrid = std::max((uint32_t) 1, params.n_seq_max);
+        } else {
+            kv_size = std::max((uint32_t) 1, params.n_seq_max);
+        }
         // it's probably best to keep as much precision as possible for the states
+        // TODO: should types be different for the two caches?
         type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
         type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
     }
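A worked example of the sizing branch above, as a stand-alone sketch rather than llama.cpp code (choose_kv_sizes is a hypothetical helper): for a hybrid model the attention cache keeps its full n_ctx cells and only kv_size_hybrid shrinks to n_seq_max, while a pure recurrent model shrinks kv_size itself and never gets a hybrid cache.

#include <algorithm>
#include <cstdint>
#include <cstdio>

struct kv_sizes { uint32_t kv_size, kv_size_hybrid; };

// Mirrors the branch above: recurrent models need only n_seq_max cells,
// and hybrid models move that shrinkage into the separate hybrid cache.
static kv_sizes choose_kv_sizes(uint32_t n_ctx, uint32_t n_seq_max,
                                bool recurrent, bool hybrid) {
    kv_sizes s = { n_ctx, 0 };
    if (recurrent) {
        if (hybrid) {
            s.kv_size_hybrid = std::max<uint32_t>(1, n_seq_max); // SSM layers only
        } else {
            s.kv_size = std::max<uint32_t>(1, n_seq_max);        // whole cache holds SSM state
        }
    }
    return s;
}

int main() {
    const kv_sizes attn   = choose_kv_sizes(4096, 1, false, false); // attention-only model
    const kv_sizes mamba  = choose_kv_sizes(4096, 1, true,  false); // pure recurrent model
    const kv_sizes hybrid = choose_kv_sizes(4096, 1, true,  true);  // hybrid model (e.g. Bamba)
    std::printf("attention-only: kv_size=%u kv_size_hybrid=%u\n", attn.kv_size,   attn.kv_size_hybrid);
    std::printf("pure recurrent: kv_size=%u kv_size_hybrid=%u\n", mamba.kv_size,  mamba.kv_size_hybrid);
    std::printf("hybrid        : kv_size=%u kv_size_hybrid=%u\n", hybrid.kv_size, hybrid.kv_size_hybrid);
    return 0;
}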
@@ -20389,24 +20413,44 @@ struct llama_context * llama_new_context_with_model(
 
         llama_set_abort_callback(ctx, params.abort_callback, params.abort_callback_data);
 
-        if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
+        // the self cache is recurrent IFF the model is recurrent, but not hybrid
+        if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv, recurrent && !hybrid)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
             return nullptr;
         }
 
         {
             // Log cache memory usage
             size_t memory_size_k = 0;
             size_t memory_size_v = 0;
 
             for (auto & k : ctx->kv_self.k_l) {
                 memory_size_k += ggml_nbytes(k);
             }
 
             for (auto & v : ctx->kv_self.v_l) {
                 memory_size_v += ggml_nbytes(v);
             }
             LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
                 (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
                 ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
         }
 
+        // For hybrid models, initialize the hybrid kv cache
+        if (kv_size_hybrid > 0 && !llama_kv_cache_init(ctx->kv_hybrid, ctx, type_k, type_v, kv_size_hybrid, cparams.offload_kqv, true)) {
+            LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
+            llama_free(ctx);
+            return nullptr;
+        }
+        {
+            // Log hybrid cache memory usage
+            size_t memory_size_k = 0;
+            size_t memory_size_v = 0;
+            for (auto & k : ctx->kv_hybrid.k_l) {
+                memory_size_k += ggml_nbytes(k);
+            }
+            for (auto & v : ctx->kv_hybrid.v_l) {
+                memory_size_v += ggml_nbytes(v);
+            }
+            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+        }
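To make the flag logic in this hunk explicit, a small stand-alone sketch (plan_caches is a hypothetical helper, not llama.cpp code): the self cache is created as recurrent only for pure recurrent models, while the hybrid cache, whenever it is allocated at all, is always the recurrent one.

#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the recurrent flags of the two init calls above.
static void plan_caches(bool recurrent, bool hybrid, uint32_t kv_size_hybrid) {
    const bool self_recurrent = recurrent && !hybrid;
    std::printf("kv_self   : recurrent=%d\n", self_recurrent);
    if (kv_size_hybrid > 0) {
        std::printf("kv_hybrid : recurrent=1, %u cells\n", kv_size_hybrid);
    }
}

int main() {
    plan_caches(false, false, 0); // attention-only transformer
    plan_caches(true,  false, 0); // pure recurrent (Mamba, RWKV)
    plan_caches(true,  true,  1); // hybrid (Bamba), with n_seq_max == 1
    return 0;
}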
@@ -20763,6 +20807,15 @@ bool llama_model_is_recurrent(const struct llama_model * model) {
     }
 }
 
+bool llama_model_is_hybrid(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_BAMBA:
+            return true;
+        default:
+            return false;
+    }
+}
+
 uint32_t llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,