wkv7 CUDA impl
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
parent
6dcc21e7f5
commit
9cd24dd3eb
5 changed files with 197 additions and 90 deletions
|
@ -1893,6 +1893,9 @@ struct test_rwkv_wkv7 : public test_case {
|
|||
ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * b = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
|
||||
// Outputs may become NaN with long seqlen without these normalization
|
||||
a = ggml_l2_norm(ctx, a, 1e-7F);
|
||||
b = ggml_l2_norm(ctx, b, 1e-7F);
|
||||
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
|
||||
ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
|
||||
return out;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue