update Q formats

This commit is contained in:
Henri Vasserman 2023-05-19 23:52:35 +03:00
parent 057c9b7dc8
commit 6df8e93234
No known key found for this signature in database
GPG key ID: 2995FC0F58B1A986

View file

@ -19,14 +19,14 @@ typedef uint uint32_t;
struct __attribute__ ((packed)) block_q4_0 struct __attribute__ ((packed)) block_q4_0
{ {
float d; half d;
uint8_t qs[16]; /* QK4_0 / 2 */ uint8_t qs[16]; /* QK4_0 / 2 */
}; };
struct __attribute__ ((packed)) block_q4_1 struct __attribute__ ((packed)) block_q4_1
{ {
float d; half d;
float m; half m;
uint8_t qs[16]; /* QK4_1 / 2 */ uint8_t qs[16]; /* QK4_1 / 2 */
}; };
@ -47,7 +47,7 @@ struct __attribute__ ((packed)) block_q5_1
struct __attribute__ ((packed)) block_q8_0 struct __attribute__ ((packed)) block_q8_0
{ {
float d; half d;
int8_t qs[32]; /* QK8_0 */ int8_t qs[32]; /* QK8_0 */
}; };
@ -56,7 +56,7 @@ __kernel void dequantize_row_q4_0(__global struct block_q4_0* x, __global float*
const uint i = get_global_id(0) / 32; /* QK4_0 */ const uint i = get_global_id(0) / 32; /* QK4_0 */
const uint j = get_local_id(0); const uint j = get_local_id(0);
const float d = x[i].d; const float d = vload_half(0, (__global half*) &x[i].d);
const int x0 = (x[i].qs[j] & 0xf) - 8; const int x0 = (x[i].qs[j] & 0xf) - 8;
const int x1 = (x[i].qs[j] >> 4) - 8; const int x1 = (x[i].qs[j] >> 4) - 8;
@ -69,8 +69,8 @@ __kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float*
const uint i = get_global_id(0) / 32; /* QK4_1 */ const uint i = get_global_id(0) / 32; /* QK4_1 */
const uint j = get_local_id(0); const uint j = get_local_id(0);
const float d = x[i].d; const float d = vload_half(0, (__global half*) &x[i].d);
const float m = x[i].m; const float m = vload_half(0, (__global half*) &x[i].m);
const int x0 = (x[i].qs[j] & 0xf); const int x0 = (x[i].qs[j] & 0xf);
const int x1 = (x[i].qs[j] >> 4); const int x1 = (x[i].qs[j] >> 4);
@ -120,7 +120,7 @@ __kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float*
const uint i = get_global_id(0) / 32; /* QK8_0 */ const uint i = get_global_id(0) / 32; /* QK8_0 */
const uint j = get_local_id(0); const uint j = get_local_id(0);
const float d = x[i].d; const float d = vload_half(0, (__global half*) &x[i].d);
y[i*32 + j] = x[i].qs[j]*d; y[i*32 + j] = x[i].qs[j]*d;
} }
@ -358,13 +358,13 @@ void ggml_cl_sgemm_wrapper(
dequant = true; dequant = true;
kernel = kernel_q4_0; kernel = kernel_q4_0;
local = 16; local = 16;
size_qb = global * (sizeof(float) + local) / 32; size_qb = global * (sizeof(ggml_fp16_t) + local) / 32;
break; break;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
dequant = true; dequant = true;
kernel = kernel_q4_1; kernel = kernel_q4_1;
local = 16; local = 16;
size_qb = global * (sizeof(float) * 2 + local) / 32; size_qb = global * (sizeof(ggml_fp16_t) * 2 + local) / 32;
break; break;
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
dequant = true; dequant = true;
@ -382,7 +382,7 @@ void ggml_cl_sgemm_wrapper(
dequant = true; dequant = true;
kernel = kernel_q8_0; kernel = kernel_q8_0;
local = 32; local = 32;
size_qb = global * (sizeof(float) + local) / 32; size_qb = global * (sizeof(ggml_fp16_t) + local) / 32;
break; break;
default: default:
fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype); fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);