feat(bamba): Partially complete work on constructing the forward graph
There are still problems at inference time with matrix dimensions not lining up, so there are likely still places where the per-layer sizes are not being used correctly.

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
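For context on the "per-layer sizes" mentioned above: Bamba interleaves Mamba-2 (SSM) layers with a handful of attention layers, so head counts live in per-layer arrays (`n_head_arr`, `n_head_kv_arr`) rather than in a single global value. The following standalone sketch models that bookkeeping the way this commit sets it up; the layer count, head counts, and attention-layer positions are made-up illustrative values, not Bamba's real configuration.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy model of the per-layer head bookkeeping this commit introduces:
// every layer starts out as an SSM layer with 0 attention heads, then the
// few attention layers receive the real head counts.
int main() {
    const uint32_t n_layer        = 8;   // illustrative only
    const uint32_t n_head_attn    = 32;  // heads on attention layers
    const uint32_t n_head_kv_attn = 8;   // KV heads (GQA) on attention layers

    std::vector<bool>     ssm_layer_arr(n_layer, true);
    std::vector<uint32_t> n_head_arr   (n_layer, 0);  // 0 => no attention heads
    std::vector<uint32_t> n_head_kv_arr(n_layer, 0);

    // hypothetical positions of the attention layers
    const std::vector<uint32_t> attn_layer_indices = {3, 7};
    for (const auto attn_idx : attn_layer_indices) {
        assert(attn_idx < n_layer);
        ssm_layer_arr[attn_idx] = false;
        n_head_arr   [attn_idx] = n_head_attn;
        n_head_kv_arr[attn_idx] = n_head_kv_attn;
    }

    // Anything that builds tensors or graph ops for layer `il` has to read
    // these per-layer values; reading a single global head count instead is
    // exactly the kind of mismatch the commit message describes.
    for (uint32_t il = 0; il < n_layer; ++il) {
        assert(n_head_arr[il] == (ssm_layer_arr[il] ? 0u : n_head_attn));
    }
    return 0;
}
```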
parent 41fc019057
commit e7b1abbc0a

1 changed file with 164 additions and 17 deletions: src/llama.cpp
@@ -5988,6 +5988,16 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
                 ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim);
 
+                // Zero-out n_head_arr and n_head_kv_arr since SSM layers don't
+                // have attention heads. We'll set them correctly below once we
+                // know which layers are attention layers
+                // NOTE: It's important that this happens after n_embd_head_[kv]
+                // are set above!
+                const auto n_head_attn    = hparams.n_head();
+                const auto n_head_kv_attn = hparams.n_head_kv();
+                std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
+                std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
+
                 // Attention params
                 std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), true);
                 std::vector<uint32_t> attn_layer_indices;
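On the "must happen after n_embd_head_[kv] are set" note in the hunk above: in llama.cpp's generic hparam loading the default per-head sizes are derived from the global head count, so that derivation has to run before the head arrays are zeroed for the SSM layers. A standalone paraphrase of that dependency, with assumed sizes:

```cpp
#include <cassert>
#include <cstdint>

// Paraphrased dependency (assumed sizes): the default per-head embedding size
// is derived from the global head count, so it must be computed before the
// head arrays are zeroed for the SSM layers.
int main() {
    const uint32_t n_embd = 4096;
    uint32_t n_head       = 32;  // what n_head() reports before the zeroing

    // derived "above", while n_head is still the attention head count
    const uint32_t n_embd_head_k = n_head > 0 ? n_embd / n_head : 0;
    assert(n_embd_head_k == 128);

    // the zeroing for SSM layers happens only afterwards; the already-derived
    // head size is unaffected. Doing it in the other order would have left
    // n_embd_head_k at 0.
    n_head = 0;
    assert(n_embd_head_k == 128);
    return 0;
}
```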
@@ -5995,6 +6005,9 @@ static void llm_load_hparams(
                 for (const auto attn_idx : attn_layer_indices) {
                     GGML_ASSERT(attn_idx < hparams.n_layer);
                     hparams.ssm_layer_arr[attn_idx] = false;
+                    // Correctly set n_head and n_head_kv for attention layers
+                    hparams.n_head_arr[attn_idx]    = n_head_attn;
+                    hparams.n_head_kv_arr[attn_idx] = n_head_kv_attn;
                 }
 
                 // Feed forward params
@@ -8726,15 +8739,13 @@ static bool llm_load_tensors(
             case LLM_ARCH_BAMBA:
                 {
                     // mamba2 Mixer SSM params
-                    // TODO: Why are these int64_t and not uint32_t?
+                    // NOTE: int64_t for tensor dimensions
                     const int64_t d_conv  = hparams.ssm_d_conv;
                     const int64_t d_inner = hparams.ssm_d_inner;
                     const int64_t d_state = hparams.ssm_d_state;
+                    const int64_t n_ssm_head = hparams.ssm_dt_rank;
                     const int64_t n_group = hparams.ssm_n_group;
-                    const int64_t head_count = hparams.ssm_head_count;
-                    const int64_t head_dim = hparams.ssm_head_dim;
-                    const int64_t chunk_size = hparams.ssm_chunk_size;
-                    const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + head_count;
+                    const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
 
                     // only an expansion factor of 2 is supported for now
                     GGML_ASSERT(2 * n_embd == d_inner);
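A note on the `d_in_proj` expression above: assuming the usual Mamba-2 `in_proj` layout (gate `z`, conv/SSM input `x`, input-dependent `B` and `C`, and one `dt` per head), `2*d_inner + 2*n_group*d_state + n_ssm_head` is simply the sum of those chunk widths. A standalone check of the arithmetic with illustrative sizes (not Bamba's real ones):

```cpp
#include <cassert>
#include <cstdint>

// Sketch of how d_in_proj decomposes, assuming the standard Mamba-2 in_proj
// ordering [z | x | B | C | dt]. Sizes are illustrative, not Bamba's real ones.
int main() {
    const int64_t d_inner    = 4096;  // 2 * n_embd with expansion factor 2
    const int64_t d_state    = 128;
    const int64_t n_group    = 1;
    const int64_t n_ssm_head = 128;   // one dt value per SSM head

    const int64_t z_sz  = d_inner;            // gating branch
    const int64_t x_sz  = d_inner;            // input to the conv / SSM branch
    const int64_t B_sz  = n_group * d_state;  // input-dependent B
    const int64_t C_sz  = n_group * d_state;  // input-dependent C
    const int64_t dt_sz = n_ssm_head;         // per-head time step

    const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
    assert(d_in_proj == z_sz + x_sz + B_sz + C_sz + dt_sz);
    return 0;
}
```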
@@ -8766,11 +8777,11 @@ static bool llm_load_tensors(
                         layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
                         layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {head_count}, 0);
+                        layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0);
 
                         // no "weight" suffix for these
-                        layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, head_count}, 0);
-                        layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, head_count}, 0);
+                        layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0);
+                        layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0);
 
                         layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
 
@@ -8778,14 +8789,17 @@ static bool llm_load_tensors(
                         layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
                     } else {
                         // attention layers (with optional bias)
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        const int64_t n_head_i       = hparams.n_head(i);
+                        const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i);
+                        const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i);
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0);
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                     }
 
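The switch to `hparams.n_head(i)` / `n_embd_k_gqa(i)` / `n_embd_v_gqa(i)` in the hunk above is what makes the attention tensors follow per-layer sizes. Assuming the usual relation `n_embd_k_gqa(il) = n_embd_head_k * n_head_kv(il)`, the standalone sketch below shows why this matters in a hybrid model where the zeroed SSM layers report zero KV heads; all sizes are illustrative.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy per-layer GQA widths, assuming n_embd_k_gqa(il) = n_embd_head_k * n_head_kv(il).
// Layer layout and sizes are illustrative, not Bamba's real configuration.
int main() {
    const int64_t n_embd_head_k = 128;
    // per-layer KV head counts: zeroed SSM layers report 0, one attention layer keeps 8
    const std::vector<int64_t> n_head_kv_arr = {0, 0, 8, 0};

    for (size_t il = 0; il < n_head_kv_arr.size(); ++il) {
        const int64_t n_embd_k_gqa_i = n_embd_head_k * n_head_kv_arr[il];
        if (n_head_kv_arr[il] == 0) {
            // SSM layer: no K/V projections exist, and the width collapses to 0
            assert(n_embd_k_gqa_i == 0);
        } else {
            // attention layer: wk must be created as {n_embd, n_embd_k_gqa_i}
            assert(n_embd_k_gqa_i == 128 * 8);
        }
    }
    return 0;
}
```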
@@ -10408,7 +10422,7 @@ static struct ggml_tensor * llm_build_mamba2(
     const int64_t d_inner = hparams.ssm_d_inner;
     const int64_t d_state = hparams.ssm_d_state;
     const int64_t n_head = hparams.ssm_dt_rank;
-    const int64_t head_dim = d_inner / n_head;
+    const int64_t head_dim = hparams.ssm_head_dim == 0 ? d_inner / n_head : hparams.ssm_head_dim;
     const int64_t n_group = hparams.ssm_n_group;
     const int64_t n_seqs = batch.n_seqs;
 
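The new `head_dim` expression in `llm_build_mamba2` treats `ssm_head_dim` as an optional override from the model metadata: use it when present, otherwise fall back to the derived `d_inner / n_head`. A minimal standalone sketch of that selection with assumed sizes:

```cpp
#include <cassert>
#include <cstdint>

// head_dim selection: an explicit metadata value wins, otherwise derive it
// from d_inner and the head count. Sizes are illustrative only.
static int64_t pick_head_dim(int64_t ssm_head_dim, int64_t d_inner, int64_t n_head) {
    return ssm_head_dim == 0 ? d_inner / n_head : ssm_head_dim;
}

int main() {
    assert(pick_head_dim(/*ssm_head_dim=*/0,  /*d_inner=*/4096, /*n_head=*/128) == 32);
    assert(pick_head_dim(/*ssm_head_dim=*/64, /*d_inner=*/4096, /*n_head=*/128) == 64);
    return 0;
}
```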
@@ -14633,6 +14647,134 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_bamba() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        // {n_embd, n_tokens}
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        struct ggml_tensor * state_copy = build_inp_s_copy();
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            if (hparams.ssm_layer(il)) {
+                // ssm layer
+                cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy,
+                        rs_zero, kv_head, n_kv, cb, il);
+            } else {
+                // attention layer //
+
+                // rope freq factors
+                struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed forward
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            // residual
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        // final rmsnorm
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_command_r() {
 
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
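Regarding the commit message's note that per-layer sizes are likely still misused somewhere: one plausible suspect in `build_bamba()` above is that the RoPE reshapes use the build-context-wide `n_head` / `n_head_kv`, while this commit zeroes those values for SSM layers and only fills them per layer for attention layers, so the per-layer accessors (`hparams.n_head(il)`, `hparams.n_head_kv(il)`) would presumably be needed in the attention branch. This is an unverified guess; the standalone sketch below only illustrates the element-count invariant such a reshape must satisfy, with made-up sizes.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy check of the reshape invariant behind ggml_reshape_3d(Qcur, n_embd_head,
// n_head, n_tokens): the product of the target dimensions must equal the number
// of elements in Qcur, and in a hybrid model that head count is per layer.
// All numbers are made up for illustration.
int main() {
    const int64_t n_embd_head = 128;
    const int64_t n_tokens    = 4;
    const std::vector<int64_t> n_head_arr = {0, 32, 0, 32};  // SSM layers hold 0

    // Qcur for layer il comes out of wq with n_embd_head * n_head(il) rows,
    // so for n_tokens tokens it holds this many elements:
    const int64_t il      = 1;  // an attention layer
    const int64_t wq_rows = n_embd_head * n_head_arr[il];
    const int64_t n_elems = wq_rows * n_tokens;

    // Reshaping to (n_embd_head, n_head(il), n_tokens) matches that count...
    assert(n_embd_head * n_head_arr[il] * n_tokens == n_elems);
    // ...while a head count borrowed from a different (here: SSM) layer does not,
    // which is the kind of dimension mismatch the commit message describes.
    const int64_t n_head_wrong = n_head_arr[0];
    assert(n_embd_head * n_head_wrong * n_tokens != n_elems);
    return 0;
}
```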
@@ -17215,6 +17357,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_mamba(/* version */ 2);
             } break;
+        case LLM_ARCH_BAMBA:
+            {
+                result = llm.build_bamba();
+            } break;
         case LLM_ARCH_XVERSE:
             {
                 result = llm.build_xverse();
@@ -20601,6 +20747,7 @@ bool llama_model_is_recurrent(const struct llama_model * model) {
    switch (model->arch) {
        case LLM_ARCH_MAMBA:
        case LLM_ARCH_MAMBA2:
+       case LLM_ARCH_BAMBA:
        case LLM_ARCH_RWKV6:
            return true;
        default: