feat(bamba): Partially complete work on constructing the forward graph

There are still problems at inference around matrix dimensions not lining
up, so there are likely still places where the per-layer sizes are not
being used correctly.

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart 2024-12-05 11:04:54 -07:00
parent 41fc019057
commit e7b1abbc0a

View file

@ -5988,6 +5988,16 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim); ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim);
// Zero-out n_head_arr and n_head_kv_arr since SSM layers don't
// have attention heads. We'll set them correctly below once we
// know which layers are attention layers
// NOTE: It's important that this happens after n_embd_head_[kv]
// are set above!
const auto n_head_attn = hparams.n_head();
const auto n_head_kv_attn = hparams.n_head_kv();
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
// Attention params // Attention params
std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), true); std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), true);
std::vector<uint32_t> attn_layer_indices; std::vector<uint32_t> attn_layer_indices;
@ -5995,6 +6005,9 @@ static void llm_load_hparams(
for (const auto attn_idx : attn_layer_indices) { for (const auto attn_idx : attn_layer_indices) {
GGML_ASSERT(attn_idx < hparams.n_layer); GGML_ASSERT(attn_idx < hparams.n_layer);
hparams.ssm_layer_arr[attn_idx] = false; hparams.ssm_layer_arr[attn_idx] = false;
// Correctly set n_head and n_head_kv for attention layers
hparams.n_head_arr[attn_idx] = n_head_attn;
hparams.n_head_kv_arr[attn_idx] = n_head_kv_attn;
} }
// Feed forward params // Feed forward params
@ -8726,15 +8739,13 @@ static bool llm_load_tensors(
case LLM_ARCH_BAMBA: case LLM_ARCH_BAMBA:
{ {
// mamba2 Mixer SSM params // mamba2 Mixer SSM params
// TODO: Why are these int64_t and not uint32_t? // NOTE: int64_t for tensor dimensions
const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_conv = hparams.ssm_d_conv;
const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_inner = hparams.ssm_d_inner;
const int64_t d_state = hparams.ssm_d_state; const int64_t d_state = hparams.ssm_d_state;
const int64_t n_ssm_head = hparams.ssm_dt_rank;
const int64_t n_group = hparams.ssm_n_group; const int64_t n_group = hparams.ssm_n_group;
const int64_t head_count = hparams.ssm_head_count; const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
const int64_t head_dim = hparams.ssm_head_dim;
const int64_t chunk_size = hparams.ssm_chunk_size;
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + head_count;
// only an expansion factor of 2 is supported for now // only an expansion factor of 2 is supported for now
GGML_ASSERT(2 * n_embd == d_inner); GGML_ASSERT(2 * n_embd == d_inner);
@ -8766,11 +8777,11 @@ static bool llm_load_tensors(
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {head_count}, 0); layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0);
// no "weight" suffix for these // no "weight" suffix for these
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, head_count}, 0); layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0);
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, head_count}, 0); layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
@ -8778,14 +8789,17 @@ static bool llm_load_tensors(
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
} else { } else {
// attention layers (with optional bias) // attention layers (with optional bias)
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); const int64_t n_head_i = hparams.n_head(i);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0);
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0);
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
} }
@ -10408,7 +10422,7 @@ static struct ggml_tensor * llm_build_mamba2(
const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_inner = hparams.ssm_d_inner;
const int64_t d_state = hparams.ssm_d_state; const int64_t d_state = hparams.ssm_d_state;
const int64_t n_head = hparams.ssm_dt_rank; const int64_t n_head = hparams.ssm_dt_rank;
const int64_t head_dim = d_inner / n_head; const int64_t head_dim = hparams.ssm_head_dim == 0 ? d_inner / n_head : hparams.ssm_head_dim;
const int64_t n_group = hparams.ssm_n_group; const int64_t n_group = hparams.ssm_n_group;
const int64_t n_seqs = batch.n_seqs; const int64_t n_seqs = batch.n_seqs;
@ -14633,6 +14647,134 @@ struct llm_build_context {
return gf; return gf;
} }
struct ggml_cgraph * build_bamba() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
// {n_embd, n_tokens}
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
struct ggml_tensor * state_copy = build_inp_s_copy();
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
if (hparams.ssm_layer(il)) {
// ssm layer
cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy,
rs_zero, kv_head, n_kv, cb, il);
} else {
// attention layer //
// rope freq factors
struct ggml_tensor * rope_factors = build_rope_factors(il);
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed forward
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
cur = llm_build_ffn(ctx0, lctx, cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
// residual
cur = ggml_add(ctx0, cur, ffn_inp);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
// final rmsnorm
cur = llm_build_norm(ctx0, inpL, hparams,
model.output_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
struct ggml_cgraph * build_command_r() { struct ggml_cgraph * build_command_r() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
@ -17215,6 +17357,10 @@ static struct ggml_cgraph * llama_build_graph(
{ {
result = llm.build_mamba(/* version */ 2); result = llm.build_mamba(/* version */ 2);
} break; } break;
case LLM_ARCH_BAMBA:
{
result = llm.build_bamba();
} break;
case LLM_ARCH_XVERSE: case LLM_ARCH_XVERSE:
{ {
result = llm.build_xverse(); result = llm.build_xverse();
@ -20601,6 +20747,7 @@ bool llama_model_is_recurrent(const struct llama_model * model) {
switch (model->arch) { switch (model->arch) {
case LLM_ARCH_MAMBA: case LLM_ARCH_MAMBA:
case LLM_ARCH_MAMBA2: case LLM_ARCH_MAMBA2:
case LLM_ARCH_BAMBA:
case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6:
return true; return true;
default: default: