Remove qk as well (the dequantize kernels now use the literal block size 32 and half-block offset 16)

This commit is contained in:
Henri Vasserman 2023-05-14 13:22:39 +03:00 committed by GitHub
parent 9074e353dd
commit 394dabbc1a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -60,8 +60,8 @@ __kernel void dequantize_row_q4_0(__global struct block_q4_0* x, __global float*
const int x0 = (x[i].qs[j] & 0xf) - 8; const int x0 = (x[i].qs[j] & 0xf) - 8;
const int x1 = (x[i].qs[j] >> 4) - 8; const int x1 = (x[i].qs[j] >> 4) - 8;
y[i*qk + j + 0 ] = x0*d; y[i*32 + j + 0 ] = x0*d;
y[i*qk + j + qk/2] = x1*d; y[i*32 + j + 16] = x1*d;
} }
__kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float* y) { __kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float* y) {
@@ -74,8 +74,8 @@ __kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float*
const int x0 = (x[i].qs[j] & 0xf); const int x0 = (x[i].qs[j] & 0xf);
const int x1 = (x[i].qs[j] >> 4); const int x1 = (x[i].qs[j] >> 4);
y[i*qk + j + 0 ] = x0*d + m; y[i*32 + j + 0 ] = x0*d + m;
y[i*qk + j + qk/2] = x1*d + m; y[i*32 + j + 16] = x1*d + m;
} }
__kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float* y) { __kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float* y) {
@@ -92,8 +92,8 @@ __kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float*
const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16; const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
y[i*qk + j + 0 ] = x0*d; y[i*32 + j + 0 ] = x0*d;
y[i*qk + j + qk/2] = x1*d; y[i*32 + j + 16] = x1*d;
} }
__kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float* y) { __kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float* y) {
@@ -111,8 +111,8 @@ __kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float*
const int x0 = (x[i].qs[j] & 0xf) | xh_0; const int x0 = (x[i].qs[j] & 0xf) | xh_0;
const int x1 = (x[i].qs[j] >> 4) | xh_1; const int x1 = (x[i].qs[j] >> 4) | xh_1;
y[i*qk + j + 0 ] = x0*d + m; y[i*32 + j + 0 ] = x0*d + m;
y[i*qk + j + qk/2] = x1*d + m; y[i*32 + j + 16] = x1*d + m;
} }
__kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float* y) { __kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float* y) {
@@ -120,7 +120,7 @@ __kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float*
const uint j = get_local_id(0); const uint j = get_local_id(0);
const float d = x[i].d; const float d = x[i].d;
y[i*qk + j] = x[i].qs[j]*d; y[i*32 + j] = x[i].qs[j]*d;
} }
); );