From 137071003c12deb8b046890edb48612a582bd686 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Tue, 25 Apr 2023 19:40:54 +0200 Subject: [PATCH] Improve btype dequant kernel selection code, add error if type is unsupported --- ggml-opencl.cpp | 59 +++++++++++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index 8a87edb88..aa426fe3f 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -114,33 +114,40 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra cl_event events[4] = { NULL }; cl_kernel kernel; - size_t global, local, size_qb; - const bool dequant = btype >= 2 && btype < 6; - if (dequant) { - global = n * k; + size_t global = n * k, local, size_qb; + bool dequant; - switch (btype) { - case GGML_TYPE_Q4_0: - kernel = kernel_q4_0; - local = 16; - size_qb = global * (sizeof(float) + local) / 32; - break; - case GGML_TYPE_Q4_1: - kernel = kernel_q4_1; - local = 16; - size_qb = global * (sizeof(float) * 2 + local) / 32; - break; - case GGML_TYPE_Q4_2: - kernel = kernel_q4_2; - local = 8; - size_qb = global * (sizeof(short) + local) / 16; - break; - case GGML_TYPE_Q4_3: - kernel = kernel_q4_3; - local = 8; - size_qb = global * (sizeof(short) * 2 + local) / 16; - break; - } + switch (btype) { + case GGML_TYPE_F32: + dequant = false; + break; + case GGML_TYPE_Q4_0: + dequant = true; + kernel = kernel_q4_0; + local = 16; + size_qb = global * (sizeof(float) + local) / 32; + break; + case GGML_TYPE_Q4_1: + dequant = true; + kernel = kernel_q4_1; + local = 16; + size_qb = global * (sizeof(float) * 2 + local) / 32; + break; + case GGML_TYPE_Q4_2: + dequant = true; + kernel = kernel_q4_2; + local = 8; + size_qb = global * (sizeof(short) + local) / 16; + break; + case GGML_TYPE_Q4_3: + dequant = true; + kernel = kernel_q4_3; + local = 8; + size_qb = global * (sizeof(short) * 2 + local) / 16; + break; + default: + fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype); + abort(); } const size_t size_a = m * k * sizeof(float);