Improve btype dequant kernel selection code, add error if type is unsupported

This commit is contained in:
0cc4m 2023-04-25 19:40:54 +02:00
parent 36bfb3c158
commit 137071003c

View file

@@ -114,33 +114,40 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra
cl_event events[4] = { NULL }; cl_event events[4] = { NULL };
cl_kernel kernel; cl_kernel kernel;
size_t global, local, size_qb; size_t global = n * k, local, size_qb;
const bool dequant = btype >= 2 && btype < 6; bool dequant;
if (dequant) {
global = n * k;
switch (btype) { switch (btype) {
case GGML_TYPE_Q4_0: case GGML_TYPE_F32:
kernel = kernel_q4_0; dequant = false;
local = 16; break;
size_qb = global * (sizeof(float) + local) / 32; case GGML_TYPE_Q4_0:
break; dequant = true;
case GGML_TYPE_Q4_1: kernel = kernel_q4_0;
kernel = kernel_q4_1; local = 16;
local = 16; size_qb = global * (sizeof(float) + local) / 32;
size_qb = global * (sizeof(float) * 2 + local) / 32; break;
break; case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2: dequant = true;
kernel = kernel_q4_2; kernel = kernel_q4_1;
local = 8; local = 16;
size_qb = global * (sizeof(short) + local) / 16; size_qb = global * (sizeof(float) * 2 + local) / 32;
break; break;
case GGML_TYPE_Q4_3: case GGML_TYPE_Q4_2:
kernel = kernel_q4_3; dequant = true;
local = 8; kernel = kernel_q4_2;
size_qb = global * (sizeof(short) * 2 + local) / 16; local = 8;
break; size_qb = global * (sizeof(short) + local) / 16;
} break;
case GGML_TYPE_Q4_3:
dequant = true;
kernel = kernel_q4_3;
local = 8;
size_qb = global * (sizeof(short) * 2 + local) / 16;
break;
default:
fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
abort();
} }
const size_t size_a = m * k * sizeof(float); const size_t size_a = m * k * sizeof(float);