Improve btype dequant kernel selection code, add error if type is unsupported

This commit is contained in:
0cc4m 2023-04-25 19:40:54 +02:00
parent 36bfb3c158
commit 137071003c

View file

@@ -114,33 +114,40 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra
cl_event events[4] = { NULL };
cl_kernel kernel;
size_t global, local, size_qb;
const bool dequant = btype >= 2 && btype < 6;
if (dequant) {
global = n * k;
size_t global = n * k, local, size_qb;
bool dequant;
switch (btype) {
case GGML_TYPE_Q4_0:
kernel = kernel_q4_0;
local = 16;
size_qb = global * (sizeof(float) + local) / 32;
break;
case GGML_TYPE_Q4_1:
kernel = kernel_q4_1;
local = 16;
size_qb = global * (sizeof(float) * 2 + local) / 32;
break;
case GGML_TYPE_Q4_2:
kernel = kernel_q4_2;
local = 8;
size_qb = global * (sizeof(short) + local) / 16;
break;
case GGML_TYPE_Q4_3:
kernel = kernel_q4_3;
local = 8;
size_qb = global * (sizeof(short) * 2 + local) / 16;
break;
}
switch (btype) {
case GGML_TYPE_F32:
dequant = false;
break;
case GGML_TYPE_Q4_0:
dequant = true;
kernel = kernel_q4_0;
local = 16;
size_qb = global * (sizeof(float) + local) / 32;
break;
case GGML_TYPE_Q4_1:
dequant = true;
kernel = kernel_q4_1;
local = 16;
size_qb = global * (sizeof(float) * 2 + local) / 32;
break;
case GGML_TYPE_Q4_2:
dequant = true;
kernel = kernel_q4_2;
local = 8;
size_qb = global * (sizeof(short) + local) / 16;
break;
case GGML_TYPE_Q4_3:
dequant = true;
kernel = kernel_q4_3;
local = 8;
size_qb = global * (sizeof(short) * 2 + local) / 16;
break;
default:
fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
abort();
}
const size_t size_a = m * k * sizeof(float);