llama: add support for QRWKV6 model architecture (#11001)
llama: add support for QRWKV6 model architecture (#11001) * WIP: Add support for RWKV6Qwen2 Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * RWKV: Some graph simplification Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * Add support for RWKV6Qwen2 with cpu and cuda GLA Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * RWKV6[QWEN2]: Concat lerp weights together to reduce cpu overhead Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * Fix some typos Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * code format changes Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * Fix wkv test & add gla test Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * Fix cuda warning Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * Update README.md Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * Update ggml/src/ggml-cuda/gla.cu Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Fix fused lerp weights loading with RWKV6 Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * better sanity check skipping for QRWKV6 in llama-quant thanks @compilade Signed-off-by: Molly Sophia <mollysophia379@gmail.com> Co-authored-by: compilade <git@compilade.net> --------- Signed-off-by: Molly Sophia <mollysophia379@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: compilade <git@compilade.net>
This commit is contained in:
parent
c6860cc734
commit
ee7136c6d1
23 changed files with 862 additions and 124 deletions
|
@ -968,6 +968,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||
"GET_REL_POS",
|
||||
"ADD_REL_POS",
|
||||
"RWKV_WKV6",
|
||||
"GATED_LINEAR_ATTN",
|
||||
|
||||
"UNARY",
|
||||
|
||||
|
@ -987,7 +988,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||
"OPT_STEP_ADAMW",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
|
||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
|
@ -1064,6 +1065,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
"get_rel_pos(x)",
|
||||
"add_rel_pos(x)",
|
||||
"rwkv_wkv6(k, v, r, tf, td, s)",
|
||||
"gated_linear_attn(k, v, q, gate, s)",
|
||||
|
||||
"unary(x)",
|
||||
|
||||
|
@ -1083,7 +1085,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
"adamw(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
|
||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
|
@ -4629,15 +4631,13 @@ struct ggml_tensor * ggml_rwkv_wkv6(
|
|||
GGML_ASSERT(ggml_is_contiguous(state));
|
||||
|
||||
const int64_t S = k->ne[0];
|
||||
const int64_t H = k->ne[2];
|
||||
const int64_t n_tokens = k->ne[3];
|
||||
const int64_t H = k->ne[1];
|
||||
const int64_t n_tokens = k->ne[2];
|
||||
const int64_t n_seqs = state->ne[1];
|
||||
{
|
||||
GGML_ASSERT(k->ne[1] == 1);
|
||||
GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
|
||||
GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
|
||||
// TODO: RWKV v4 and v5
|
||||
GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
|
||||
GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
|
||||
GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
|
||||
GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
|
||||
GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
|
||||
}
|
||||
|
||||
|
@ -4656,6 +4656,49 @@ struct ggml_tensor * ggml_rwkv_wkv6(
|
|||
return result;
|
||||
}
|
||||
|
||||
// ggml_gated_linear_attn
|
||||
|
||||
struct ggml_tensor * ggml_gated_linear_attn(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * q,
|
||||
struct ggml_tensor * g,
|
||||
struct ggml_tensor * state,
|
||||
float scale) {
|
||||
GGML_ASSERT(ggml_is_contiguous(k));
|
||||
GGML_ASSERT(ggml_is_contiguous(v));
|
||||
GGML_ASSERT(ggml_is_contiguous(q));
|
||||
GGML_ASSERT(ggml_is_contiguous(g));
|
||||
GGML_ASSERT(ggml_is_contiguous(state));
|
||||
|
||||
const int64_t S = k->ne[0];
|
||||
const int64_t H = k->ne[1];
|
||||
const int64_t n_tokens = k->ne[2];
|
||||
const int64_t n_seqs = state->ne[1];
|
||||
{
|
||||
GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
|
||||
GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
|
||||
GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
|
||||
GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
|
||||
}
|
||||
|
||||
// concat output and new_state
|
||||
const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
ggml_set_op_params_f32(result, 0, scale);
|
||||
|
||||
result->op = GGML_OP_GATED_LINEAR_ATTN;
|
||||
result->src[0] = k;
|
||||
result->src[1] = v;
|
||||
result->src[2] = q;
|
||||
result->src[3] = g;
|
||||
result->src[4] = state;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_unary
|
||||
|
||||
static struct ggml_tensor * ggml_unary_impl(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue