update Q formats

This commit is contained in:
Henri Vasserman 2023-05-19 23:52:35 +03:00
parent 057c9b7dc8
commit 6df8e93234
No known key found for this signature in database
GPG key ID: 2995FC0F58B1A986

View file

@ -19,14 +19,14 @@ typedef uint uint32_t;
struct __attribute__ ((packed)) block_q4_0
{
float d;
half d;
uint8_t qs[16]; /* QK4_0 / 2 */
};
struct __attribute__ ((packed)) block_q4_1
{
float d;
float m;
half d;
half m;
uint8_t qs[16]; /* QK4_1 / 2 */
};
@ -47,7 +47,7 @@ struct __attribute__ ((packed)) block_q5_1
struct __attribute__ ((packed)) block_q8_0
{
float d;
half d;
int8_t qs[32]; /* QK8_0 */
};
@ -56,7 +56,7 @@ __kernel void dequantize_row_q4_0(__global struct block_q4_0* x, __global float*
const uint i = get_global_id(0) / 32; /* QK4_0 */
const uint j = get_local_id(0);
const float d = x[i].d;
const float d = vload_half(0, (__global half*) &x[i].d);
const int x0 = (x[i].qs[j] & 0xf) - 8;
const int x1 = (x[i].qs[j] >> 4) - 8;
@ -69,8 +69,8 @@ __kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float*
const uint i = get_global_id(0) / 32; /* QK4_1 */
const uint j = get_local_id(0);
const float d = x[i].d;
const float m = x[i].m;
const float d = vload_half(0, (__global half*) &x[i].d);
const float m = vload_half(0, (__global half*) &x[i].m);
const int x0 = (x[i].qs[j] & 0xf);
const int x1 = (x[i].qs[j] >> 4);
@ -120,7 +120,7 @@ __kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float*
const uint i = get_global_id(0) / 32; /* QK8_0 */
const uint j = get_local_id(0);
const float d = x[i].d;
const float d = vload_half(0, (__global half*) &x[i].d);
y[i*32 + j] = x[i].qs[j]*d;
}
@ -358,13 +358,13 @@ void ggml_cl_sgemm_wrapper(
dequant = true;
kernel = kernel_q4_0;
local = 16;
size_qb = global * (sizeof(float) + local) / 32;
size_qb = global * (sizeof(ggml_fp16_t) + local) / 32;
break;
case GGML_TYPE_Q4_1:
dequant = true;
kernel = kernel_q4_1;
local = 16;
size_qb = global * (sizeof(float) * 2 + local) / 32;
size_qb = global * (sizeof(ggml_fp16_t) * 2 + local) / 32;
break;
case GGML_TYPE_Q5_0:
dequant = true;
@ -382,7 +382,7 @@ void ggml_cl_sgemm_wrapper(
dequant = true;
kernel = kernel_q8_0;
local = 32;
size_qb = global * (sizeof(float) + local) / 32;
size_qb = global * (sizeof(ggml_fp16_t) + local) / 32;
break;
default:
fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);