diff --git a/ggml-opencl.c b/ggml-opencl.c index 361dc07ed..c9e7418d5 100644 --- a/ggml-opencl.c +++ b/ggml-opencl.c @@ -19,14 +19,14 @@ typedef uint uint32_t; struct __attribute__ ((packed)) block_q4_0 { - float d; + half d; uint8_t qs[16]; /* QK4_0 / 2 */ }; struct __attribute__ ((packed)) block_q4_1 { - float d; - float m; + half d; + half m; uint8_t qs[16]; /* QK4_1 / 2 */ }; @@ -47,7 +47,7 @@ struct __attribute__ ((packed)) block_q5_1 struct __attribute__ ((packed)) block_q8_0 { - float d; + half d; int8_t qs[32]; /* QK8_0 */ }; @@ -56,7 +56,7 @@ __kernel void dequantize_row_q4_0(__global struct block_q4_0* x, __global float* const uint i = get_global_id(0) / 32; /* QK4_0 */ const uint j = get_local_id(0); - const float d = x[i].d; + const float d = vload_half(0, (__global half*) &x[i].d); const int x0 = (x[i].qs[j] & 0xf) - 8; const int x1 = (x[i].qs[j] >> 4) - 8; @@ -69,8 +69,8 @@ __kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float* const uint i = get_global_id(0) / 32; /* QK4_1 */ const uint j = get_local_id(0); - const float d = x[i].d; - const float m = x[i].m; + const float d = vload_half(0, (__global half*) &x[i].d); + const float m = vload_half(0, (__global half*) &x[i].m); const int x0 = (x[i].qs[j] & 0xf); const int x1 = (x[i].qs[j] >> 4); @@ -120,7 +120,7 @@ __kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float* const uint i = get_global_id(0) / 32; /* QK8_0 */ const uint j = get_local_id(0); - const float d = x[i].d; + const float d = vload_half(0, (__global half*) &x[i].d); y[i*32 + j] = x[i].qs[j]*d; } @@ -358,13 +358,13 @@ void ggml_cl_sgemm_wrapper( dequant = true; kernel = kernel_q4_0; local = 16; - size_qb = global * (sizeof(float) + local) / 32; + size_qb = global * (sizeof(ggml_fp16_t) + local) / 32; break; case GGML_TYPE_Q4_1: dequant = true; kernel = kernel_q4_1; local = 16; - size_qb = global * (sizeof(float) * 2 + local) / 32; + size_qb = global * (sizeof(ggml_fp16_t) * 2 + local) / 32; break; case GGML_TYPE_Q5_0: dequant = true; @@ -382,7 +382,7 @@ void ggml_cl_sgemm_wrapper( dequant = true; kernel = kernel_q8_0; local = 32; - size_qb = global * (sizeof(float) + local) / 32; + size_qb = global * (sizeof(ggml_fp16_t) + local) / 32; break; default: fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);