Improve btype dequant kernel selection code, add error if type is unsupported
This commit is contained in:
parent
36bfb3c158
commit
137071003c
1 changed file with 33 additions and 26 deletions
|
@ -114,33 +114,40 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra
|
|||
cl_event events[4] = { NULL };
|
||||
|
||||
cl_kernel kernel;
|
||||
size_t global, local, size_qb;
|
||||
const bool dequant = btype >= 2 && btype < 6;
|
||||
if (dequant) {
|
||||
global = n * k;
|
||||
size_t global = n * k, local, size_qb;
|
||||
bool dequant;
|
||||
|
||||
switch (btype) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
kernel = kernel_q4_0;
|
||||
local = 16;
|
||||
size_qb = global * (sizeof(float) + local) / 32;
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
kernel = kernel_q4_1;
|
||||
local = 16;
|
||||
size_qb = global * (sizeof(float) * 2 + local) / 32;
|
||||
break;
|
||||
case GGML_TYPE_Q4_2:
|
||||
kernel = kernel_q4_2;
|
||||
local = 8;
|
||||
size_qb = global * (sizeof(short) + local) / 16;
|
||||
break;
|
||||
case GGML_TYPE_Q4_3:
|
||||
kernel = kernel_q4_3;
|
||||
local = 8;
|
||||
size_qb = global * (sizeof(short) * 2 + local) / 16;
|
||||
break;
|
||||
}
|
||||
switch (btype) {
|
||||
case GGML_TYPE_F32:
|
||||
dequant = false;
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
dequant = true;
|
||||
kernel = kernel_q4_0;
|
||||
local = 16;
|
||||
size_qb = global * (sizeof(float) + local) / 32;
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
dequant = true;
|
||||
kernel = kernel_q4_1;
|
||||
local = 16;
|
||||
size_qb = global * (sizeof(float) * 2 + local) / 32;
|
||||
break;
|
||||
case GGML_TYPE_Q4_2:
|
||||
dequant = true;
|
||||
kernel = kernel_q4_2;
|
||||
local = 8;
|
||||
size_qb = global * (sizeof(short) + local) / 16;
|
||||
break;
|
||||
case GGML_TYPE_Q4_3:
|
||||
dequant = true;
|
||||
kernel = kernel_q4_3;
|
||||
local = 8;
|
||||
size_qb = global * (sizeof(short) * 2 + local) / 16;
|
||||
break;
|
||||
default:
|
||||
fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
|
||||
abort();
|
||||
}
|
||||
|
||||
const size_t size_a = m * k * sizeof(float);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue