fix: per-layer recurrent embd_[kv]_s
For hybrid models, this value should be 0 for the non-recurrent layers.

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
parent f2478bcab5
commit d3a34e0282
1 changed file with 28 additions and 21 deletions
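In short: the recurrent-state size helpers in llama_hparams now take a layer index and return 0 for layers of a hybrid model that are not recurrent, so attention layers stop reserving cache space for conv/ssm states they never use. The following is a minimal standalone sketch of that idea with made-up dimensions; it mirrors the fields and helpers touched in the diff below but is not the llama.cpp code itself.

// Standalone sketch only -- illustrative dimensions, not the llama.cpp implementation.
#include <array>
#include <cstdint>
#include <cstdio>

constexpr uint32_t MAX_LAYERS = 8; // stands in for LLAMA_MAX_LAYERS

struct hparams_sketch {
    uint32_t n_layer     = 4;
    uint32_t ssm_d_conv  = 4;
    uint32_t ssm_d_inner = 64;
    uint32_t ssm_d_state = 16;
    uint32_t ssm_n_group = 1;

    // true = recurrent (SSM) layer, false = attention layer of the hybrid model
    std::array<bool, MAX_LAYERS> recurrent_layer_arr{};

    bool recurrent_layer(uint32_t il) const { return recurrent_layer_arr[il]; }

    // conv-state size per cache cell; 0 for the non-recurrent layers
    uint32_t n_embd_k_s(uint32_t il) const {
        if (!recurrent_layer(il)) {
            return 0;
        }
        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
    }

    // ssm-state size per cache cell; 0 for the non-recurrent layers
    uint32_t n_embd_v_s(uint32_t il) const {
        if (!recurrent_layer(il)) {
            return 0;
        }
        return ssm_d_state * ssm_d_inner;
    }
};

int main() {
    hparams_sketch hp;
    hp.recurrent_layer_arr = {true, true, false, true}; // layer 2 is an attention layer
    for (uint32_t il = 0; il < hp.n_layer; ++il) {
        std::printf("layer %u: n_embd_k_s = %u, n_embd_v_s = %u\n", il, hp.n_embd_k_s(il), hp.n_embd_v_s(il));
    }
    return 0;
}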
@@ -2485,7 +2485,7 @@ struct llama_hparams {
     uint32_t ssm_head_dim = 0;
 
     // for hybrid state space models
-    std::array<bool, LLAMA_MAX_LAYERS> ssm_layer_arr;
+    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
 
     float f_clamp_kqv = 0.0f;
     float f_max_alibi_bias = 0.0f;
@@ -2544,7 +2544,7 @@ struct llama_hparams {
         if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;
         if (this->ssm_head_dim != other.ssm_head_dim) return true;
 
-        if (this->ssm_layer_arr != other.ssm_layer_arr) return true;
+        if (this->recurrent_layer_arr != other.recurrent_layer_arr) return true;
 
         if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true;
         if (this->time_mix_extra_dim != other.time_mix_extra_dim) return true;
@@ -2616,30 +2616,34 @@ struct llama_hparams {
         return n_embd_head_v * n_head_kv;
     }
 
-    uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
+    uint32_t n_embd_k_s(uint32_t il = 0) const { // dimension of the rolling state embeddings
         // corresponds to Mamba's conv_states size or RWKV's token_shift states size
+        if (!recurrent_layer(il)) {
+            return 0;
+        }
         if (wkv_head_size != 0) {
             // for RWKV models
             return 2 * n_embd;
-        } else {
-            // TODO: maybe support other convolution strides than 1
-            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-            return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
         }
+        // TODO: maybe support other convolution strides than 1
+        // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
     }
 
-    uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
+    uint32_t n_embd_v_s(uint32_t il = 0) const { // dimension of the recurrent state embeddings
+        if (!recurrent_layer(il)) {
+            return 0;
+        }
         if (wkv_head_size != 0) {
             // corresponds to RWKV's wkv_states size
             return n_embd * wkv_head_size;
-        } else {
-            // corresponds to Mamba's ssm_states size
-            return ssm_d_state * ssm_d_inner;
         }
+        // corresponds to Mamba's ssm_states size
+        return ssm_d_state * ssm_d_inner;
     }
 
-    bool ssm_layer(uint32_t il) const {
-        return ssm_layer_arr[il];
+    bool recurrent_layer(uint32_t il) const {
+        return recurrent_layer_arr[il];
     }
 };
 
@@ -3555,8 +3559,8 @@ static bool llama_kv_cache_init(
     cache.v_l.reserve(n_layer);
 
     for (int i = 0; i < (int) n_layer; i++) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);
 
         ggml_backend_buffer_type_t buft;
         if (offload) {
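For a rough sense of what the hunk above saves in llama_kv_cache_init, a back-of-the-envelope check with illustrative Mamba-2-style numbers (d_conv = 4, d_inner = 4096, d_state = 128, n_group = 1; these are made up, not a real Bamba config): a recurrent layer still contributes its conv-state and ssm-state rows, while an attention layer's K/V buffers are now sized by n_embd_k_gqa/n_embd_v_gqa alone instead of being padded by the recurrent-state sizes.

// Compile-time check with made-up dimensions; not taken from any real model config.
#include <cstdint>

constexpr uint32_t ssm_d_conv  = 4;
constexpr uint32_t ssm_d_inner = 4096;
constexpr uint32_t ssm_d_state = 128;
constexpr uint32_t ssm_n_group = 1;

// recurrent layer: (4 - 1) * (4096 + 2*1*128) = 13056 conv-state elements per cache cell
constexpr uint32_t k_s_recurrent = (ssm_d_conv - 1) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
static_assert(k_s_recurrent == 13056, "conv-state row size of a recurrent layer");

// recurrent layer: 128 * 4096 = 524288 ssm-state elements per cache cell
constexpr uint32_t v_s_recurrent = ssm_d_state * ssm_d_inner;
static_assert(v_s_recurrent == 524288, "ssm-state row size of a recurrent layer");

// attention layer of the hybrid model: n_embd_k_s(i) and n_embd_v_s(i) are 0 after this fix,
// so nothing is added on top of n_embd_k_gqa(i) / n_embd_v_gqa(i).

int main() { return 0; }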
@@ -5509,7 +5513,10 @@ static void llm_load_hparams(
     std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
     std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
     std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
-    std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), false);
+    std::fill(
+        hparams.recurrent_layer_arr.begin(),
+        hparams.recurrent_layer_arr.end(),
+        llama_model_is_recurrent(&model));
 
     ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer);
     ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
@@ -5999,12 +6006,12 @@ static void llm_load_hparams(
                 std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
 
                 // Attention params
-                std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), true);
+                std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true);
                 std::vector<uint32_t> attn_layer_indices;
                 ml.get_arr(LLM_KV_ATTENTION_LAYER_INDICES, attn_layer_indices);
                 for (const auto attn_idx : attn_layer_indices) {
                     GGML_ASSERT(attn_idx < hparams.n_layer);
-                    hparams.ssm_layer_arr[attn_idx] = false;
+                    hparams.recurrent_layer_arr[attn_idx] = false;
                     // Correctly set n_head and n_head_kv for attention layers
                     hparams.n_head_arr[attn_idx] = n_head_attn;
                     hparams.n_head_kv_arr[attn_idx] = n_head_kv_attn;
@@ -7162,7 +7169,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     }
 
     if (model.arch == LLM_ARCH_BAMBA) {
-        LLAMA_LOG_INFO("%s: ssm_layer_arr = %s\n", __func__, print_f([&](uint32_t il) { return uint32_t(hparams.ssm_layer(il)); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: recurrent_layer_arr = %s\n", __func__, print_f([&](uint32_t il) { return uint32_t(hparams.recurrent_layer(il)); }, hparams.n_layer).c_str());
     }
 }
 
@@ -8769,7 +8776,7 @@ static bool llm_load_tensors(
                     // norm
                     layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                    if (hparams.ssm_layer(i)) {
+                    if (hparams.recurrent_layer(i)) {
                         // ssm layers
                         layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
                         layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, d_in_proj}, llama_model_loader::TENSOR_NOT_REQUIRED);
@@ -14677,7 +14684,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-            if (hparams.ssm_layer(il)) {
+            if (hparams.recurrent_layer(il)) {
                 // ssm layer
                 cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy,
                         rs_zero, kv_head, n_kv, cb, il);