Fixed case order in ggml_vk_graph_compute

This commit is contained in:
niansa 2023-07-05 14:21:16 +02:00
parent 856b7589e9
commit 77ebe46966

View file

@ -1075,21 +1075,25 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
{ {
ggml_vk_diag_mask_inf(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02); ggml_vk_diag_mask_inf(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02);
} break; } break;
case GGML_OP_NORM:
{
ggml_vk_norm(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ggml_nrows(src0));
} break;
case GGML_OP_RMS_NORM:
{
ggml_vk_rms_norm(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ggml_nrows(src0));
} break;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
{ {
if (src0->type == GGML_TYPE_F16 if (src0->type == GGML_TYPE_F16
&& src1->type == GGML_TYPE_F32) { && src1->type == GGML_TYPE_F32) {
ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, nb01, nb02, ne11, ne12, nb11, nb12, ne0, ne1); ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, nb01, nb02, ne11, ne12, nb11, nb12, ne0, ne1);
break; break;
} else {
printf("Unsupported quantization: %u/%u\n", src0->type, src1->type);
} }
} }
case GGML_OP_NORM: { default: {}
ggml_vk_norm(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ggml_nrows(src0));
} break;
case GGML_OP_RMS_NORM: {
ggml_vk_rms_norm(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ggml_nrows(src0));
} break;
default:
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
//GGML_ASSERT(false); //GGML_ASSERT(false);
} }