Fixed case order in ggml_vk_graph_compute

This commit is contained in:
niansa 2023-07-05 14:21:16 +02:00
parent 856b7589e9
commit 77ebe46966

View file

@ -1075,21 +1075,25 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
{
ggml_vk_diag_mask_inf(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02);
} break;
case GGML_OP_NORM:
{
ggml_vk_norm(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ggml_nrows(src0));
} break;
case GGML_OP_RMS_NORM:
{
ggml_vk_rms_norm(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ggml_nrows(src0));
} break;
case GGML_OP_MUL_MAT:
{
if (src0->type == GGML_TYPE_F16
&& src1->type == GGML_TYPE_F32) {
&& src1->type == GGML_TYPE_F32) {
ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, nb01, nb02, ne11, ne12, nb11, nb12, ne0, ne1);
break;
} else {
printf("Unsupported quantization: %u/%u\n", src0->type, src1->type);
}
}
case GGML_OP_NORM: {
ggml_vk_norm(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ggml_nrows(src0));
} break;
case GGML_OP_RMS_NORM: {
ggml_vk_rms_norm(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ggml_nrows(src0));
} break;
default:
default: {}
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
//GGML_ASSERT(false);
}