Improve btype dequant kernel selection code, add error if type is unsupported
This commit is contained in:
parent
36bfb3c158
commit
137071003c
1 changed file with 33 additions and 26 deletions
|
@@ -114,33 +114,40 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra
|
||||||
cl_event events[4] = { NULL };
|
cl_event events[4] = { NULL };
|
||||||
|
|
||||||
cl_kernel kernel;
|
cl_kernel kernel;
|
||||||
size_t global, local, size_qb;
|
size_t global = n * k, local, size_qb;
|
||||||
const bool dequant = btype >= 2 && btype < 6;
|
bool dequant;
|
||||||
if (dequant) {
|
|
||||||
global = n * k;
|
|
||||||
|
|
||||||
switch (btype) {
|
switch (btype) {
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_F32:
|
||||||
kernel = kernel_q4_0;
|
dequant = false;
|
||||||
local = 16;
|
break;
|
||||||
size_qb = global * (sizeof(float) + local) / 32;
|
case GGML_TYPE_Q4_0:
|
||||||
break;
|
dequant = true;
|
||||||
case GGML_TYPE_Q4_1:
|
kernel = kernel_q4_0;
|
||||||
kernel = kernel_q4_1;
|
local = 16;
|
||||||
local = 16;
|
size_qb = global * (sizeof(float) + local) / 32;
|
||||||
size_qb = global * (sizeof(float) * 2 + local) / 32;
|
break;
|
||||||
break;
|
case GGML_TYPE_Q4_1:
|
||||||
case GGML_TYPE_Q4_2:
|
dequant = true;
|
||||||
kernel = kernel_q4_2;
|
kernel = kernel_q4_1;
|
||||||
local = 8;
|
local = 16;
|
||||||
size_qb = global * (sizeof(short) + local) / 16;
|
size_qb = global * (sizeof(float) * 2 + local) / 32;
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q4_3:
|
case GGML_TYPE_Q4_2:
|
||||||
kernel = kernel_q4_3;
|
dequant = true;
|
||||||
local = 8;
|
kernel = kernel_q4_2;
|
||||||
size_qb = global * (sizeof(short) * 2 + local) / 16;
|
local = 8;
|
||||||
break;
|
size_qb = global * (sizeof(short) + local) / 16;
|
||||||
}
|
break;
|
||||||
|
case GGML_TYPE_Q4_3:
|
||||||
|
dequant = true;
|
||||||
|
kernel = kernel_q4_3;
|
||||||
|
local = 8;
|
||||||
|
size_qb = global * (sizeof(short) * 2 + local) / 16;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
|
||||||
|
abort();
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t size_a = m * k * sizeof(float);
|
const size_t size_a = m * k * sizeof(float);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue