From e7b1abbc0a0e6aaf34ecbc3545cbaabd8f7e3592 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Thu, 5 Dec 2024 11:04:54 -0700
Subject: [PATCH] feat(bamba): Partially complete work on constructing the forward graph

There are still problems at inference around matrix dimensions not lining
up, so there are likely still places where the per-layer sizes are not
being used correctly.

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart
---
 src/llama.cpp | 181 +++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 164 insertions(+), 17 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 0e568779b..ade7e52f3 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -5988,6 +5988,16 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
                 ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim);
 
+                // Zero-out n_head_arr and n_head_kv_arr since SSM layers don't
+                // have attention heads. We'll set them correctly below once we
+                // know which layers are attention layers
+                // NOTE: It's important that this happens after n_embd_head_[kv]
+                //   are set above!
+                const auto n_head_attn = hparams.n_head();
+                const auto n_head_kv_attn = hparams.n_head_kv();
+                std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
+                std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
+
                 // Attention params
                 std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), true);
                 std::vector<uint32_t> attn_layer_indices;
@@ -5995,6 +6005,9 @@ static void llm_load_hparams(
                 for (const auto attn_idx : attn_layer_indices) {
                     GGML_ASSERT(attn_idx < hparams.n_layer);
                     hparams.ssm_layer_arr[attn_idx] = false;
+                    // Correctly set n_head and n_head_kv for attention layers
+                    hparams.n_head_arr[attn_idx] = n_head_attn;
+                    hparams.n_head_kv_arr[attn_idx] = n_head_kv_attn;
                 }
 
                 // Feed forward params
@@ -8726,15 +8739,13 @@ static bool llm_load_tensors(
        case LLM_ARCH_BAMBA:
            {
                // mamba2 Mixer SSM params
-               // TODO: Why are these int64_t and not uint32_t?
+               // NOTE: int64_t for tensor dimensions
                const int64_t d_conv = hparams.ssm_d_conv;
                const int64_t d_inner = hparams.ssm_d_inner;
                const int64_t d_state = hparams.ssm_d_state;
+               const int64_t n_ssm_head = hparams.ssm_dt_rank;
                const int64_t n_group = hparams.ssm_n_group;
-               const int64_t head_count = hparams.ssm_head_count;
-               const int64_t head_dim = hparams.ssm_head_dim;
-               const int64_t chunk_size = hparams.ssm_chunk_size;
-               const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + head_count;
+               const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
 
                // only an expansion factor of 2 is supported for now
                GGML_ASSERT(2 * n_embd == d_inner);
@@ -8766,11 +8777,11 @@ static bool llm_load_tensors(
                        layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
                        layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                       layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {head_count}, 0);
+                       layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0);
 
                        // no "weight" suffix for these
-                       layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, head_count}, 0);
-                       layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, head_count}, 0);
+                       layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0);
+                       layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0);
 
                        layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
@@ -8778,14 +8789,17 @@ static bool llm_load_tensors(
                        layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
                    } else {
                        // attention layers (with optional bias)
-                       layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                       layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                       layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                       layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-                       layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                       layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                       layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                       layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                       const int64_t n_head_i = hparams.n_head(i);
+                       const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i);
+                       const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i);
+                       layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0);
+                       layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0);
+                       layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0);
+                       layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0);
+                       layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                       layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                       layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                       layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
                    }
                }
 
@@ -10408,7 +10422,7 @@ static struct ggml_tensor * llm_build_mamba2(
    const int64_t d_inner = hparams.ssm_d_inner;
    const int64_t d_state = hparams.ssm_d_state;
    const int64_t n_head = hparams.ssm_dt_rank;
-   const int64_t head_dim = d_inner / n_head;
+   const int64_t head_dim = hparams.ssm_head_dim == 0 ? d_inner / n_head : hparams.ssm_head_dim;
    const int64_t n_group = hparams.ssm_n_group;
 
    const int64_t n_seqs = batch.n_seqs;
@@ -14633,6 +14647,134 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_bamba() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        // {n_embd, n_tokens}
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        struct ggml_tensor * state_copy = build_inp_s_copy();
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            if (hparams.ssm_layer(il)) {
+                // ssm layer
+                cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy,
+                        rs_zero, kv_head, n_kv, cb, il);
+            } else {
+                // attention layer //
+
+                // rope freq factors
+                struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed forward
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
+                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            // residual
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        // final rmsnorm
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_command_r() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
@@ -17215,6 +17357,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_mamba(/* version */ 2);
             } break;
+        case LLM_ARCH_BAMBA:
+            {
+                result = llm.build_bamba();
+            } break;
         case LLM_ARCH_XVERSE:
             {
                 result = llm.build_xverse();
@@ -20601,6 +20747,7 @@ bool llama_model_is_recurrent(const struct llama_model * model) {
    switch (model->arch) {
        case LLM_ARCH_MAMBA:
        case LLM_ARCH_MAMBA2:
+       case LLM_ARCH_BAMBA:
        case LLM_ARCH_RWKV6:
            return true;
        default: