RWKV_WKV6 Vulkan op tests passed

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia 2024-12-13 17:19:23 +08:00
parent 4651f5e2f2
commit 77fe4fd982
3 changed files with 6 additions and 5 deletions

View file

@ -1961,7 +1961,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
"main",
7,
sizeof(vk_op_rwkv_wkv6_push_constants),
{64, 1, 1}, // work group
{1, 1, 1}, // work group
{device->subgroup_size},
1
);
@ -8344,11 +8344,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} else if (tensor->op == GGML_OP_LEAKY_RELU) {
const float * op_params = (const float *)tensor->op_params;
tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
tensor->src[4], tensor->src[5]);
}
// else if (tensor->op == GGML_OP_RWKV_WKV6) {
// tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
// tensor->src[4], tensor->src[5]);
// }
else {
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
GGML_ABORT("fatal error");

View file

@ -479,6 +479,8 @@ void process_shaders() {
string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"C_TYPE", "float"}, {"D_TYPE", "float"}, {"E_TYPE", "float"}, {"F_TYPE", "float"}, {"S_TYPE", "float"}}));
for (auto &c : compiles) {
c.wait();
}