fix: per-layer recurrent embd_[kv]_s

For hybrid models, these values should be 0 for the non-recurrent layers

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Gabe Goodhart 2024-12-09 15:51:32 -07:00
parent f2478bcab5
commit d3a34e0282


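For reference before the diff: a minimal, self-contained sketch (not part of this commit) of what the per-layer accessors are meant to do in a hybrid model. Recurrent-state sizes are reported only for layers flagged in recurrent_layer_arr and are 0 for attention layers. The struct name, layer layout, and hyperparameter values below are invented for illustration, and the RWKV branch is omitted for brevity.

// Illustrative sketch only -- simplified stand-ins for llama_hparams'
// per-layer recurrent-state accessors. All concrete values are made up.
#include <array>
#include <cstdint>
#include <cstdio>

struct hparams_sketch {
    static constexpr uint32_t max_layers = 8;
    std::array<bool, max_layers> recurrent_layer_arr{};  // true = SSM layer, false = attention layer

    // assumed Mamba-2-style hyperparameters (hypothetical values)
    uint32_t ssm_d_conv  = 4;
    uint32_t ssm_d_inner = 4096;
    uint32_t ssm_n_group = 1;
    uint32_t ssm_d_state = 128;

    bool recurrent_layer(uint32_t il) const { return recurrent_layer_arr[il]; }

    // conv-state rows: 0 for the non-recurrent (attention) layers of a hybrid model
    uint32_t n_embd_k_s(uint32_t il) const {
        if (!recurrent_layer(il)) {
            return 0;
        }
        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
    }

    // ssm-state rows: likewise 0 for non-recurrent layers
    uint32_t n_embd_v_s(uint32_t il) const {
        if (!recurrent_layer(il)) {
            return 0;
        }
        return ssm_d_state * ssm_d_inner;
    }
};

int main() {
    hparams_sketch hp;
    // hypothetical hybrid layout: layers 0-5 recurrent, layers 6-7 attention
    hp.recurrent_layer_arr = {true, true, true, true, true, true, false, false};

    for (uint32_t il = 0; il < hparams_sketch::max_layers; ++il) {
        std::printf("layer %u: n_embd_k_s = %u, n_embd_v_s = %u\n",
                    il, hp.n_embd_k_s(il), hp.n_embd_v_s(il));
    }
    return 0;
}

Running this prints non-zero state sizes only for the recurrent layers, which mirrors how llama_kv_cache_init below now sizes each layer's cache row with hparams.n_embd_k_s(i) / hparams.n_embd_v_s(i).
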
@@ -2485,7 +2485,7 @@ struct llama_hparams {
     uint32_t ssm_head_dim = 0;
 
     // for hybrid state space models
-    std::array<bool, LLAMA_MAX_LAYERS> ssm_layer_arr;
+    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
 
     float f_clamp_kqv = 0.0f;
     float f_max_alibi_bias = 0.0f;
@@ -2544,7 +2544,7 @@ struct llama_hparams {
         if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;
         if (this->ssm_head_dim != other.ssm_head_dim) return true;
-        if (this->ssm_layer_arr != other.ssm_layer_arr) return true;
+        if (this->recurrent_layer_arr != other.recurrent_layer_arr) return true;
 
         if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true;
         if (this->time_mix_extra_dim != other.time_mix_extra_dim) return true;
@@ -2616,30 +2616,34 @@ struct llama_hparams {
         return n_embd_head_v * n_head_kv;
     }
 
-    uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
+    uint32_t n_embd_k_s(uint32_t il = 0) const { // dimension of the rolling state embeddings
         // corresponds to Mamba's conv_states size or RWKV's token_shift states size
+        if (!recurrent_layer(il)) {
+            return 0;
+        }
         if (wkv_head_size != 0) {
             // for RWKV models
             return 2 * n_embd;
-        } else {
-            // TODO: maybe support other convolution strides than 1
-            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-            return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
-        }
+        }
+        // TODO: maybe support other convolution strides than 1
+        // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
     }
 
-    uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
+    uint32_t n_embd_v_s(uint32_t il = 0) const { // dimension of the recurrent state embeddings
+        if (!recurrent_layer(il)) {
+            return 0;
+        }
         if (wkv_head_size != 0) {
             // corresponds to RWKV's wkv_states size
             return n_embd * wkv_head_size;
-        } else {
-            // corresponds to Mamba's ssm_states size
-            return ssm_d_state * ssm_d_inner;
-        }
+        }
+        // corresponds to Mamba's ssm_states size
+        return ssm_d_state * ssm_d_inner;
     }
 
-    bool ssm_layer(uint32_t il) const {
-        return ssm_layer_arr[il];
+    bool recurrent_layer(uint32_t il) const {
+        return recurrent_layer_arr[il];
     }
 
 };
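
To put numbers on the formulas above (hypothetical hyperparameters, not taken from this commit): with ssm_d_conv = 4, ssm_d_inner = 4096, ssm_n_group = 1, and ssm_d_state = 128, a recurrent layer reports n_embd_k_s(il) = (4 - 1) * (4096 + 2*1*128) = 13056 and n_embd_v_s(il) = 128 * 4096 = 524288, while an attention layer in the same hybrid model now reports 0 for both.
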
@@ -3555,8 +3559,8 @@ static bool llama_kv_cache_init(
     cache.v_l.reserve(n_layer);
 
     for (int i = 0; i < (int) n_layer; i++) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);
 
         ggml_backend_buffer_type_t buft;
         if (offload) {
@@ -5509,7 +5513,10 @@ static void llm_load_hparams(
     std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
     std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
     std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
-    std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), false);
+    std::fill(
+        hparams.recurrent_layer_arr.begin(),
+        hparams.recurrent_layer_arr.end(),
+        llama_model_is_recurrent(&model));
 
     ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer);
     ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
@@ -5999,12 +6006,12 @@ static void llm_load_hparams(
                 std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
 
                 // Attention params
-                std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), true);
+                std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true);
                 std::vector<uint32_t> attn_layer_indices;
                 ml.get_arr(LLM_KV_ATTENTION_LAYER_INDICES, attn_layer_indices);
                 for (const auto attn_idx : attn_layer_indices) {
                     GGML_ASSERT(attn_idx < hparams.n_layer);
-                    hparams.ssm_layer_arr[attn_idx] = false;
+                    hparams.recurrent_layer_arr[attn_idx] = false;
                     // Correctly set n_head and n_head_kv for attention layers
                     hparams.n_head_arr[attn_idx] = n_head_attn;
                     hparams.n_head_kv_arr[attn_idx] = n_head_kv_attn;
@@ -7162,7 +7169,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     }
 
     if (model.arch == LLM_ARCH_BAMBA) {
-        LLAMA_LOG_INFO("%s: ssm_layer_arr = %s\n", __func__, print_f([&](uint32_t il) { return uint32_t(hparams.ssm_layer(il)); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: recurrent_layer_arr = %s\n", __func__, print_f([&](uint32_t il) { return uint32_t(hparams.recurrent_layer(il)); }, hparams.n_layer).c_str());
     }
 }
@@ -8769,7 +8776,7 @@ static bool llm_load_tensors(
                     // norm
                     layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                    if (hparams.ssm_layer(i)) {
+                    if (hparams.recurrent_layer(i)) {
                         // ssm layers
                         layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
                         layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, d_in_proj}, llama_model_loader::TENSOR_NOT_REQUIRED);
@@ -14677,7 +14684,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-            if (hparams.ssm_layer(il)) {
+            if (hparams.recurrent_layer(il)) {
                 // ssm layer
                 cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy,
                                        rs_zero, kv_head, n_kv, cb, il);