llama : fix rwkv inference (#11618)

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia 2025-02-03 20:17:50 +08:00 committed by GitHub
parent 74b0807245
commit 1eca8916b5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 409 additions and 349 deletions

View file

@ -1970,6 +1970,228 @@ ggml_tensor * llama_context::build_mamba_layer(
}
ggml_tensor * llama_context::build_rwkv_token_shift_load(
ggml_context * ctx0,
ggml_cgraph * graph,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case) {
const auto & hparams = model.hparams;
const auto token_shift_count = hparams.token_shift_count;
const auto & n_tokens = ubatch.n_tokens;
const int64_t n_seqs = ubatch.n_seqs;
struct ggml_tensor * token_shift_all = kv_self.k_l[il];
struct ggml_tensor * token_shift = build_copy_mask_state(
ctx0, graph, token_shift_all, state_copy, state_mask,
n_tokens, hparams.n_embd_k_s(), n_seqs, worst_case);
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
return token_shift;
}
ggml_tensor * llama_context::build_rwkv_token_shift_store(
ggml_context * ctx0,
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il,
bool worst_case) {
const auto & hparams = model.hparams;
const auto token_shift_count = hparams.token_shift_count;
const auto n_embd = hparams.n_embd;
const auto & n_tokens = ubatch.n_tokens;
const int64_t n_seqs = ubatch.n_seqs;
const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
return ggml_cpy(
ctx0,
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
);
}
ggml_tensor * llama_context::build_rwkv6_time_mix(
ggml_context * ctx0,
ggml_cgraph * graph,
ggml_tensor * cur,
ggml_tensor * x_prev,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case) {
const auto & hparams = model.hparams;
const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;
const auto n_embd = hparams.n_embd;
const auto head_size = hparams.wkv_head_size;
const auto n_head = n_embd / head_size;
const auto n_head_kv = hparams.n_head_kv(il);
const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
const auto layer = &model.layers[il];
bool is_qrwkv = layer->time_mix_first == nullptr;
struct ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
struct ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->time_mix_lerp_x), cur);
xxx = ggml_reshape_4d(
ctx0,
ggml_tanh(
ctx0,
ggml_mul_mat(ctx0, layer->time_mix_w1, xxx)
),
layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
);
xxx = ggml_cont(ctx0, ggml_permute(ctx0, xxx, 0, 1, 3, 2));
xxx = ggml_mul_mat(
ctx0,
ggml_reshape_4d(
ctx0,
layer->time_mix_w2,
layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
),
xxx
);
struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
if (layer->time_mix_lerp_fused) {
// fusing these weights makes some performance improvement
sx = ggml_reshape_3d(ctx0, sx, n_embd, 1, n_tokens);
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
xxx = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xxx, layer->time_mix_lerp_fused), sx), cur);
xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
} else {
// for backward compatibility
xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
xw = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xw, layer->time_mix_lerp_w), sx), cur);
xk = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xk, layer->time_mix_lerp_k), sx), cur);
xv = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xv, layer->time_mix_lerp_v), sx), cur);
xr = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xr, layer->time_mix_lerp_r), sx), cur);
xg = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xg, layer->time_mix_lerp_g), sx), cur);
}
struct ggml_tensor * r = build_lora_mm(ctx0, layer->time_mix_receptance, xr);
struct ggml_tensor * k = build_lora_mm(ctx0, layer->time_mix_key, xk);
struct ggml_tensor * v = build_lora_mm(ctx0, layer->time_mix_value, xv);
if (layer->time_mix_receptance_b) {
r = ggml_add(ctx0, r, layer->time_mix_receptance_b);
}
if (layer->time_mix_key_b) {
k = ggml_add(ctx0, k, layer->time_mix_key_b);
}
if (layer->time_mix_value_b) {
v = ggml_add(ctx0, v, layer->time_mix_value_b);
}
struct ggml_tensor * g = build_lora_mm(ctx0, layer->time_mix_gate, xg);
if (is_qrwkv) {
g = ggml_sigmoid(ctx0, g);
} else {
g = ggml_silu(ctx0, g);
}
if (n_head_kv != 0 && n_head_kv != n_head) {
GGML_ASSERT(n_head % n_head_kv == 0);
k = ggml_reshape_4d(ctx0, k, head_size, 1, n_head_kv, n_tokens);
v = ggml_reshape_4d(ctx0, v, head_size, 1, n_head_kv, n_tokens);
struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_size, n_head / n_head_kv, n_head_kv, n_tokens);
k = ggml_repeat(ctx0, k, tmp);
v = ggml_repeat(ctx0, v, tmp);
}
k = ggml_reshape_3d(ctx0, k, head_size, n_head, n_tokens);
v = ggml_reshape_3d(ctx0, v, head_size, n_head, n_tokens);
r = ggml_reshape_3d(ctx0, r, head_size, n_head, n_tokens);
struct ggml_tensor * w = ggml_mul_mat(
ctx0,
layer->time_mix_decay_w2,
ggml_tanh(
ctx0,
ggml_mul_mat(ctx0, layer->time_mix_decay_w1, xw)
)
);
w = ggml_add(ctx0, w, layer->time_mix_decay);
w = ggml_exp(ctx0, ggml_neg(ctx0, ggml_exp(ctx0, w)));
w = ggml_reshape_3d(ctx0, w, head_size, n_head, n_tokens);
if (is_qrwkv) {
// k = k * (1 - w)
k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
}
struct ggml_tensor * wkv_state = build_copy_mask_state(
ctx0, graph, kv_self.v_l[il], state_copy, state_mask,
n_tokens, hparams.n_embd_v_s(), n_seqs, worst_case);
struct ggml_tensor * wkv_output;
if (is_qrwkv) {
wkv_output = ggml_gated_linear_attn(ctx0, k, v, r, w, wkv_state, pow(head_size, -0.5f));
} else {
wkv_output = ggml_rwkv_wkv6(ctx0, k, v, r, layer->time_mix_first, w, wkv_state);
}
cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
ggml_build_forward_expand(
graph,
ggml_cpy(
ctx0,
wkv_state,
ggml_view_1d(
ctx0,
kv_self.v_l[il],
hparams.n_embd_v_s() * n_seqs,
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
)
)
);
if (!is_qrwkv) {
// group norm with head_count groups
cur = ggml_reshape_3d(ctx0, cur, n_embd / n_head, n_head, n_tokens);
cur = ggml_norm(ctx0, cur, 64e-5f);
// Convert back to regular vectors.
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer->time_mix_ln), layer->time_mix_ln_b);
} else {
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
}
cur = ggml_mul(ctx0, cur, g);
cur = build_lora_mm(ctx0, layer->time_mix_output, cur);
return cur;
}
// llama output
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {

View file

@ -248,6 +248,33 @@ struct llama_context {
int il,
bool worst_case);
ggml_tensor * build_rwkv_token_shift_load(
ggml_context * ctx0,
ggml_cgraph * graph,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case);
ggml_tensor * build_rwkv_token_shift_store(
ggml_context * ctx0,
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il,
bool worst_case);
ggml_tensor * build_rwkv6_time_mix(
ggml_context * ctx0,
ggml_cgraph * graph,
ggml_tensor * cur,
ggml_tensor * x_prev,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case);
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]

View file

@ -574,175 +574,34 @@ struct llm_build_context {
return cur;
}
//struct ggml_tensor * build_rwkv6_time_mix(
// const struct llama_layer * layer,
// struct ggml_tensor * cur,
// struct ggml_tensor * x_prev,
// struct ggml_tensor ** wkv_state,
// size_t wkv_head_size,
// size_t head_count_kv) {
// size_t n_embd = cur->ne[0];
// size_t n_seq_tokens = cur->ne[1];
// size_t n_seqs = cur->ne[2];
struct ggml_tensor * build_rwkv_channel_mix(
const struct llama_layer * layer,
struct ggml_tensor * cur,
struct ggml_tensor * x_prev,
const llm_arch arch) {
struct ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
switch (arch) {
case LLM_ARCH_RWKV6:
{
struct ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
struct ggml_tensor * xr = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_r), cur);
// size_t head_size = wkv_head_size;
// size_t head_count = n_embd / head_size;
struct ggml_tensor * r = ggml_sigmoid(ctx0, build_lora_mm(layer->channel_mix_receptance, xr));
struct ggml_tensor * k = ggml_sqr(
ctx0,
ggml_relu(
ctx0,
build_lora_mm(layer->channel_mix_key, xk)
)
);
cur = ggml_mul(ctx0, r, build_lora_mm(layer->channel_mix_value, k));
} break;
default:
GGML_ABORT("fatal error");
}
// size_t n_tokens = n_seqs * n_seq_tokens;
// bool is_qrwkv = layer->time_mix_first == nullptr;
// struct ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
// sx = ggml_reshape_2d(ctx0, sx, n_embd, n_tokens);
// cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
// struct ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->time_mix_lerp_x), cur);
// xxx = ggml_reshape_4d(
// ctx0,
// ggml_tanh(
// ctx0,
// ggml_mul_mat(ctx0, layer->time_mix_w1, xxx)
// ),
// layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
// );
// xxx = ggml_cont(ctx0, ggml_permute(ctx0, xxx, 0, 1, 3, 2));
// xxx = ggml_mul_mat(
// ctx0,
// ggml_reshape_4d(
// ctx0,
// layer->time_mix_w2,
// layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
// ),
// xxx
// );
// struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
// if (layer->time_mix_lerp_fused) {
// // fusing these weights makes some performance improvement
// sx = ggml_reshape_3d(ctx0, sx, n_embd, 1, n_tokens);
// cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
// xxx = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xxx, layer->time_mix_lerp_fused), sx), cur);
// xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
// xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
// xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
// xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
// xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
// } else {
// // for backward compatibility
// xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
// xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
// xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
// xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
// xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
// xw = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xw, layer->time_mix_lerp_w), sx), cur);
// xk = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xk, layer->time_mix_lerp_k), sx), cur);
// xv = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xv, layer->time_mix_lerp_v), sx), cur);
// xr = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xr, layer->time_mix_lerp_r), sx), cur);
// xg = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xg, layer->time_mix_lerp_g), sx), cur);
// }
// struct ggml_tensor * r = build_lora_mm(layer->time_mix_receptance, xr);
// struct ggml_tensor * k = build_lora_mm(layer->time_mix_key, xk);
// struct ggml_tensor * v = build_lora_mm(layer->time_mix_value, xv);
// if (layer->time_mix_receptance_b) {
// r = ggml_add(ctx0, r, layer->time_mix_receptance_b);
// }
// if (layer->time_mix_key_b) {
// k = ggml_add(ctx0, k, layer->time_mix_key_b);
// }
// if (layer->time_mix_value_b) {
// v = ggml_add(ctx0, v, layer->time_mix_value_b);
// }
// struct ggml_tensor * g = build_lora_mm(layer->time_mix_gate, xg);
// if (is_qrwkv) {
// g = ggml_sigmoid(ctx0, g);
// } else {
// g = ggml_silu(ctx0, g);
// }
// if (head_count_kv != head_count) {
// GGML_ASSERT(head_count % head_count_kv == 0);
// k = ggml_reshape_4d(ctx0, k, head_size, 1, head_count_kv, n_tokens);
// v = ggml_reshape_4d(ctx0, v, head_size, 1, head_count_kv, n_tokens);
// struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_size, head_count / head_count_kv, head_count_kv, n_tokens);
// k = ggml_repeat(ctx0, k, tmp);
// v = ggml_repeat(ctx0, v, tmp);
// }
// k = ggml_reshape_3d(ctx0, k, head_size, head_count, n_tokens);
// v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
// r = ggml_reshape_3d(ctx0, r, head_size, head_count, n_tokens);
// struct ggml_tensor * w = ggml_mul_mat(
// ctx0,
// layer->time_mix_decay_w2,
// ggml_tanh(
// ctx0,
// ggml_mul_mat(ctx0, layer->time_mix_decay_w1, xw)
// )
// );
// w = ggml_add(ctx0, w, layer->time_mix_decay);
// w = ggml_exp(ctx0, ggml_neg(ctx0, ggml_exp(ctx0, w)));
// w = ggml_reshape_3d(ctx0, w, head_size, head_count, n_tokens);
// if (is_qrwkv) {
// // k = k * (1 - w)
// k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
// }
// struct ggml_tensor * wkv_output;
// if (!layer->time_mix_first) {
// wkv_output = ggml_gated_linear_attn(ctx0, k, v, r, w, *wkv_state, pow(head_size, -0.5f));
// } else {
// wkv_output = ggml_rwkv_wkv6(ctx0, k, v, r, layer->time_mix_first, w, *wkv_state);
// }
// cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
// *wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
// if (!is_qrwkv) {
// // group norm with head_count groups
// cur = ggml_reshape_3d(ctx0, cur, n_embd / head_count, head_count, n_tokens);
// cur = ggml_norm(ctx0, cur, 64e-5f);
// // Convert back to regular vectors.
// cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
// cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer->time_mix_ln), layer->time_mix_ln_b);
// } else {
// cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
// }
// cur = ggml_mul(ctx0, cur, g);
// cur = build_lora_mm(layer->time_mix_output, cur);
// return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
//}
//struct ggml_tensor * build_rwkv6_channel_mix(
// const struct llama_layer * layer,
// struct ggml_tensor * cur,
// struct ggml_tensor * x_prev) {
// struct ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
// struct ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
// struct ggml_tensor * xr = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_r), cur);
// struct ggml_tensor * r = ggml_sigmoid(ctx0, build_lora_mm(layer->channel_mix_receptance, xr));
// struct ggml_tensor * k = ggml_sqr(
// ctx0,
// ggml_relu(
// ctx0,
// build_lora_mm(layer->channel_mix_key, xk)
// )
// );
// return ggml_mul(ctx0, r, build_lora_mm(layer->channel_mix_value, k));
//}
return cur;
}
struct ggml_cgraph * build_k_shift() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
@ -6935,226 +6794,178 @@ struct llm_build_context {
return gf;
}
//ggml_cgraph * build_rwkv6() {
// struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
ggml_cgraph * build_rwkv6() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
// // Token shift state dimensions should be 2 * n_emb
// GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
GGML_ASSERT(hparams.token_shift_count == 2);
// const int64_t n_seqs = ubatch.n_seqs;
// const int64_t n_seq_tokens = ubatch.n_seq_tokens;
// const int64_t n_tokens = ubatch.n_tokens;
// GGML_ASSERT(n_seqs != 0);
// GGML_ASSERT(ubatch.equal_seqs);
// GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
// struct ggml_tensor * cur;
// struct ggml_tensor * inpL;
// struct ggml_tensor * state_copy = build_inp_s_copy();
// struct ggml_tensor * state_mask = build_inp_s_mask();
inpL = build_inp_embd(model.tok_embd);
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
// inpL = build_inp_embd(model.tok_embd);
// inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
struct ggml_tensor * state_copy = lctx.build_inp_s_copy(ctx0, worst_case);
struct ggml_tensor * state_mask = lctx.build_inp_s_mask(ctx0, worst_case);
// for (int il = 0; il < n_layer; ++il) {
// const llama_layer * layer = &model.layers[il];
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
const auto n_seqs = ubatch.n_seqs;
// // (ab)using the KV cache to store the states
// struct ggml_tensor * token_shift = build_copy_mask_state(
// gf, kv_self.k_l[il], state_copy, state_mask,
// hparams.n_embd_k_s(), n_seqs);
for (int il = 0; il < n_layer; ++il) {
const llama_layer * layer = &model.layers[il];
// struct ggml_tensor * wkv_states = build_copy_mask_state(
// gf, kv_self.v_l[il], state_copy, state_mask,
// hparams.n_embd_v_s(), n_seqs);
struct ggml_tensor * token_shift = lctx.build_rwkv_token_shift_load(
ctx0, gf, state_copy, state_mask, ubatch, il, worst_case
);
// cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
// token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_seqs);
struct ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
struct ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
// struct ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
// struct ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
struct ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM, il);
cb(att_norm, "attn_norm", il);
// struct ggml_tensor * x_norm_att = build_norm(cur, layer->attn_norm, layer->attn_norm_b, LLM_NORM, il);
// struct ggml_tensor * x_prev = ggml_concat(
// ctx0,
// att_shift,
// ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
// 1
// );
struct ggml_tensor * x_prev = ggml_concat(
ctx0,
att_shift,
ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0),
1
);
// cur = ggml_add(ctx0, cur, build_rwkv6_time_mix(layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, n_embd / hparams.wkv_head_size));
// ggml_build_forward_expand(gf, cur);
// ggml_build_forward_expand(
// gf,
// ggml_cpy(
// ctx0,
// wkv_states,
// ggml_view_1d(
// ctx0,
// kv_self.v_l[il],
// hparams.n_embd_v_s() * n_seqs,
// hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
// )
// )
// );
cur = lctx.build_rwkv6_time_mix(ctx0, gf, att_norm, x_prev, state_copy, state_mask, ubatch, il, worst_case);
// struct ggml_tensor * x_norm_ffn = build_norm(cur, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, il);
// x_prev = ggml_concat(
// ctx0,
// ffn_shift,
// ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
// 1
// );
// cur = ggml_add(ctx0, cur, build_rwkv6_channel_mix(layer, x_norm_ffn, x_prev));
// ggml_build_forward_expand(gf, cur);
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
// struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
// struct ggml_tensor * last_norm_ffn = ggml_view_3d(ctx0, x_norm_ffn, n_embd, 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_ffn));
struct ggml_tensor * ffn_norm = build_norm(ffn_inp, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, il);
cb(ffn_norm, "ffn_norm", il);
// token_shift = ggml_concat(ctx0, last_norm_att, last_norm_ffn, 1);
x_prev = ggml_concat(
ctx0,
ffn_shift,
ggml_view_3d(ctx0, ffn_norm, n_embd, n_seq_tokens - 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], 0),
1
);
// ggml_build_forward_expand(
// gf,
// ggml_cpy(
// ctx0,
// ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * 2, 0),
// ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
// )
// );
cur = build_rwkv_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
cur = ggml_add(ctx0, cur, ffn_inp);
// if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
// cur = ggml_scale(ctx0, cur, 0.5F);
// }
token_shift = ggml_concat(ctx0,
ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)),
ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)),
1
);
ggml_build_forward_expand(gf, lctx.build_rwkv_token_shift_store(ctx0, token_shift, ubatch, il, worst_case));
// cur = lctx.cvec.apply_to(ctx0, cur, il);
// cb(cur, "l_out", il);
if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
cur = ggml_scale(ctx0, cur, 0.5F);
}
// // input for next layer
// inpL = cur;
// }
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
// cur = inpL;
// struct ggml_tensor * inp_out_ids = build_inp_out_ids();
// cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
// cur = ggml_get_rows(ctx0, cur, inp_out_ids);
// input for next layer
inpL = cur;
}
// cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1);
// cb(cur, "result_norm", -1);
cur = inpL;
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
// cur = build_lora_mm(model.output, cur);
// cb(cur, "result_output", -1);
cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1);
cb(cur, "result_norm", -1);
// ggml_build_forward_expand(gf, cur);
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
// return gf;
//}
ggml_build_forward_expand(gf, cur);
return gf;
}
// ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
//ggml_cgraph * build_rwkv6qwen2() {
// struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
ggml_cgraph * build_rwkv6qwen2() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
// GGML_ASSERT(n_embd == hparams.n_embd_k_s());
GGML_ASSERT(n_embd == hparams.n_embd_k_s());
// const int64_t n_seqs = ubatch.n_seqs;
// const int64_t n_seq_tokens = ubatch.n_seq_tokens;
// const int64_t n_tokens = ubatch.n_tokens;
// GGML_ASSERT(n_seqs != 0);
// GGML_ASSERT(ubatch.equal_seqs);
// GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
// struct ggml_tensor * cur;
// struct ggml_tensor * inpL;
// struct ggml_tensor * state_copy = build_inp_s_copy();
// struct ggml_tensor * state_mask = build_inp_s_mask();
inpL = build_inp_embd(model.tok_embd);
// inpL = build_inp_embd(model.tok_embd);
struct ggml_tensor * state_copy = lctx.build_inp_s_copy(ctx0, worst_case);
struct ggml_tensor * state_mask = lctx.build_inp_s_mask(ctx0, worst_case);
// for (int il = 0; il < n_layer; ++il) {
// const llama_layer * layer = &model.layers[il];
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
const auto n_seqs = ubatch.n_seqs;
// // (ab)using the KV cache to store the states
// struct ggml_tensor * token_shift = build_copy_mask_state(
// gf, kv_self.k_l[il], state_copy, state_mask,
// hparams.n_embd_k_s(), n_seqs);
inpL = build_inp_embd(model.tok_embd);
// struct ggml_tensor * wkv_states = build_copy_mask_state(
// gf, kv_self.v_l[il], state_copy, state_mask,
// hparams.n_embd_v_s(), n_seqs);
for (int il = 0; il < n_layer; ++il) {
const llama_layer * layer = &model.layers[il];
// cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
// token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 1, n_seqs);
struct ggml_tensor * token_shift = lctx.build_rwkv_token_shift_load(
ctx0, gf, state_copy, state_mask, ubatch, il, worst_case
);
// struct ggml_tensor * x_norm_att = build_norm(cur, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
// struct ggml_tensor * x_prev = ggml_concat(
// ctx0,
// token_shift,
// ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
// 1
// );
struct ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
cb(att_norm, "attn_norm", il);
// struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
// ggml_build_forward_expand(
// gf,
// ggml_cpy(
// ctx0,
// ggml_view_1d(ctx0, last_norm_att, n_embd * n_seqs, 0),
// ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
struct ggml_tensor * x_prev = ggml_concat(
ctx0,
token_shift,
ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0),
1
);
// struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, build_rwkv6_time_mix(layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, hparams.n_head_kv()));
// ggml_build_forward_expand(gf, ffn_inp);
// ggml_build_forward_expand(
// gf,
// ggml_cpy(
// ctx0,
// wkv_states,
// ggml_view_1d(
// ctx0,
// kv_self.v_l[il],
// hparams.n_embd_v_s() * n_seqs,
// hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
// )
// )
// );
cur = lctx.build_rwkv6_time_mix(ctx0, gf, att_norm, x_prev, state_copy, state_mask, ubatch, il, worst_case);
// cb(ffn_inp, "ffn_inp", il);
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
ggml_build_forward_expand(gf, lctx.build_rwkv_token_shift_store(ctx0, token_shift, ubatch, il, worst_case));
// // feed-forward network
// cur = build_norm(ffn_inp,
// model.layers[il].ffn_norm, NULL,
// LLM_NORM_RMS, il);
// cb(cur, "ffn_norm", il);
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
// cur = build_ffn(cur,
// model.layers[il].ffn_up, NULL, NULL,
// model.layers[il].ffn_gate, NULL, NULL,
// model.layers[il].ffn_down, NULL, NULL,
// NULL,
// LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
// cb(cur, "ffn_out", il);
// feed-forward network
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
// cur = ggml_add(ctx0, cur, ffn_inp);
// cur = lctx.cvec.apply_to(ctx0, cur, il);
// cb(cur, "l_out", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
// // input for next layer
// inpL = cur;
// }
cur = ggml_add(ctx0, cur, ffn_inp);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
// cur = inpL;
// struct ggml_tensor * inp_out_ids = build_inp_out_ids();
// cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
// cur = ggml_get_rows(ctx0, cur, inp_out_ids);
// input for next layer
inpL = cur;
}
// cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1);
// cb(cur, "result_norm", -1);
cur = inpL;
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
// cur = build_lora_mm(model.output, cur);
// cb(cur, "result_output", -1);
cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
// ggml_build_forward_expand(gf, cur);
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
// return gf;
//}
ggml_build_forward_expand(gf, cur);
return gf;
}
// ref: https://github.com/facebookresearch/chameleon
// based on the original build_llama() function, changes:
@ -7726,14 +7537,14 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_exaone();
} break;
//case LLM_ARCH_RWKV6:
// {
// result = llm.build_rwkv6();
// } break;
//case LLM_ARCH_RWKV6QWEN2:
// {
// result = llm.build_rwkv6qwen2();
// } break;
case LLM_ARCH_RWKV6:
{
result = llm.build_rwkv6();
} break;
case LLM_ARCH_RWKV6QWEN2:
{
result = llm.build_rwkv6qwen2();
} break;
case LLM_ARCH_CHAMELEON:
{
result = llm.build_chameleon();