WIP: Add support for rwkv v7

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia 2025-01-15 20:43:23 +08:00
parent 5445300758
commit 6dcc21e7f5
14 changed files with 952 additions and 48 deletions

View file

@ -1868,6 +1868,37 @@ struct test_rwkv_wkv6 : public test_case {
}
};
// GGML_OP_RWKV_WKV7
struct test_rwkv_wkv7 : public test_case {
const ggml_type type;
const int64_t head_count;
const int64_t head_size;
const int64_t n_seq_tokens;
const int64_t n_seqs;
std::string vars() override {
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
}
test_rwkv_wkv7(ggml_type type = GGML_TYPE_F32,
int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t n_tokens = n_seq_tokens * n_seqs;
ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * w = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * a = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * b = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
return out;
}
};
// GGML_OP_GATED_LINEAR_ATTN
struct test_gla : public test_case {
const ggml_type type;
@ -4026,6 +4057,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 1, 1));
test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 1));
test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 4));
test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 128, 4));
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));