RWKV: Some graph simplification

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia 2024-12-25 15:29:19 +08:00
parent f298f03970
commit 385b611d45
7 changed files with 36 additions and 34 deletions

View file

@ -513,6 +513,7 @@ extern "C" {
GGML_OP_GET_REL_POS,
GGML_OP_ADD_REL_POS,
GGML_OP_RWKV_WKV6,
GGML_OP_GATED_LINEAR_ATTENTION,
GGML_OP_UNARY,
@ -1873,6 +1874,14 @@ extern "C" {
struct ggml_tensor * td,
struct ggml_tensor * state);
GGML_API struct ggml_tensor * ggml_gated_linear_attn(
struct ggml_context * ctx,
struct ggml_tensor * q,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * g,
struct ggml_tensor * state);
// custom operators
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);

View file

@ -11803,9 +11803,9 @@ static void ggml_compute_forward_add_rel_pos(
static void ggml_compute_forward_rwkv_wkv6_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const int64_t T = dst->src[1]->ne[3];
const int64_t T = dst->src[1]->ne[2];
const int64_t C = dst->ne[0];
const int64_t HEADS = dst->src[1]->ne[2];
const int64_t HEADS = dst->src[1]->ne[1];
const int64_t n_seqs = dst->src[5]->ne[1];
const int64_t head_size = C / HEADS;

View file

@ -73,9 +73,9 @@ void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
const float * s_d = (const float *)dst->src[5]->data;
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[3];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[2];
const int64_t H = dst->src[0]->ne[1];
float * dst_d = (float *)dst->data;

View file

@ -107,9 +107,9 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* s
float* dst_d = (float*)dst->data;
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[3];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[2];
const int64_t H = dst->src[0]->ne[1];
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);

View file

@ -5624,9 +5624,9 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
}
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const size_t seq_length = dst->src[0]->ne[3];
const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
const size_t n_heads = dst->src[0]->ne[2];
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[5]->ne[1];
ggml_vk_op_f32_rwkv6(

View file

@ -968,6 +968,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GET_REL_POS",
"ADD_REL_POS",
"RWKV_WKV6",
"GATED_LINEAR_ATTN",
"UNARY",
@ -987,7 +988,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"OPT_STEP_ADAMW",
};
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -1064,6 +1065,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"get_rel_pos(x)",
"add_rel_pos(x)",
"rwkv_wkv6(k, v, r, tf, td, s)",
"gated_linear_attn(k, v, q, decay, s)",
"unary(x)",
@ -1083,7 +1085,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"adamw(x)",
};
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -4640,15 +4642,13 @@ struct ggml_tensor * ggml_rwkv_wkv6(
GGML_ASSERT(ggml_is_contiguous(state));
const int64_t S = k->ne[0];
const int64_t H = k->ne[2];
const int64_t n_tokens = k->ne[3];
const int64_t H = k->ne[1];
const int64_t n_tokens = k->ne[2];
const int64_t n_seqs = state->ne[1];
{
GGML_ASSERT(k->ne[1] == 1);
GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
// TODO: RWKV v4 and v5
GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
}

View file

@ -142,11 +142,11 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
const int64_t H = 123;
const int64_t n_tokens = 123;
const int64_t n_seqs = 123;
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens);
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
ggml_tensor * r = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
ggml_tensor * tf = w;
ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
ggml_tensor * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
} break;
@ -3296,9 +3296,7 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
struct ggml_tensor * cur,
struct ggml_tensor * x_prev,
struct ggml_tensor ** wkv_state,
size_t head_count_kv,
bool key_decay,
bool skip_groupnorm) {
size_t head_count_kv) {
size_t n_embd = cur->ne[0];
size_t n_seq_tokens = cur->ne[1];
size_t n_seqs = cur->ne[2];
@ -3429,19 +3427,14 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
)
);
w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embd));
w = ggml_add(ctx, w, layer->time_mix_decay);
w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
r = ggml_reshape_4d(ctx, r, 1, head_size, head_count, n_tokens);
k = ggml_reshape_4d(ctx, k, 1, head_size, head_count, n_tokens);
v = ggml_reshape_4d(ctx, v, 1, head_size, head_count, n_tokens);
w = ggml_reshape_3d(ctx, w, head_size, head_count, n_tokens);
if (is_qrwkv) {
// k = k * (1 - w)
k = ggml_sub(ctx, k, ggml_mul(ctx, k, w));
}
k = ggml_transpose(ctx, k);
struct ggml_tensor * wkv_output;
if (!layer->time_mix_first) {
@ -3452,7 +3445,7 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
*wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
if (!skip_groupnorm) {
if (!is_qrwkv) {
// group norm with head_count groups
cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
cur = ggml_norm(ctx, cur, 64e-5f);
@ -9880,7 +9873,7 @@ struct llm_build_context {
1
);
cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, layer->time_mix_first->ne[1], false, false));
cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, layer->time_mix_first->ne[1]));
ggml_build_forward_expand(gf, cur);
ggml_build_forward_expand(
gf,
@ -10003,7 +9996,7 @@ struct llm_build_context {
)
);
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.n_head_kv(), true, hparams.wkv_skip_groupnorm));
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.n_head_kv()));
cb(ffn_inp, "ffn_inp", il);
// feed-forward network