wkv7 CUDA impl

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia 2025-01-16 15:50:56 +08:00
parent 6dcc21e7f5
commit 9cd24dd3eb
5 changed files with 197 additions and 90 deletions

View file

@@ -1893,6 +1893,9 @@ struct test_rwkv_wkv7 : public test_case {
ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * a = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * b = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
// Without these normalizations, outputs may become NaN for long sequence lengths
a = ggml_l2_norm(ctx, a, 1e-7F);
b = ggml_l2_norm(ctx, b, 1e-7F);
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
return out;