RWKV_WKV6 Vulkan op tests passed
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
parent
4651f5e2f2
commit
77fe4fd982
3 changed files with 6 additions and 5 deletions
|
@ -1961,7 +1961,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
"main",
|
||||
7,
|
||||
sizeof(vk_op_rwkv_wkv6_push_constants),
|
||||
{64, 1, 1}, // work group
|
||||
{1, 1, 1}, // work group
|
||||
{device->subgroup_size},
|
||||
1
|
||||
);
|
||||
|
@ -8344,11 +8344,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|||
} else if (tensor->op == GGML_OP_LEAKY_RELU) {
|
||||
const float * op_params = (const float *)tensor->op_params;
|
||||
tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
|
||||
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
|
||||
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
|
||||
tensor->src[4], tensor->src[5]);
|
||||
}
|
||||
// else if (tensor->op == GGML_OP_RWKV_WKV6) {
|
||||
// tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
|
||||
// tensor->src[4], tensor->src[5]);
|
||||
// }
|
||||
else {
|
||||
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
|
@ -479,6 +479,8 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"C_TYPE", "float"}, {"D_TYPE", "float"}, {"E_TYPE", "float"}, {"F_TYPE", "float"}, {"S_TYPE", "float"}}));
|
||||
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue