diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 6aab3ddae..232109762 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -1075,21 +1075,25 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph { ggml_vk_diag_mask_inf(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02); } break; + case GGML_OP_NORM: + { + ggml_vk_norm(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ggml_nrows(src0)); + } break; + case GGML_OP_RMS_NORM: + { + ggml_vk_rms_norm(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ggml_nrows(src0)); + } break; case GGML_OP_MUL_MAT: { if (src0->type == GGML_TYPE_F16 - && src1->type == GGML_TYPE_F32) { + && src1->type == GGML_TYPE_F32) { ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, nb01, nb02, ne11, ne12, nb11, nb12, ne0, ne1); break; + } else { + printf("Unsupported quantization: %u/%u\n", src0->type, src1->type); } } - case GGML_OP_NORM: { - ggml_vk_norm(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ggml_nrows(src0)); - } break; - case GGML_OP_RMS_NORM: { - ggml_vk_rms_norm(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ggml_nrows(src0)); - } break; - default: + default: {} fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); //GGML_ASSERT(false); }