Fix half2 decomposition: read the `dm` scale/min pair with __low2half()/__high2half() instead of `.x`/`.y` member access

This commit is contained in:
ardfork 2023-07-31 20:35:00 +03:00 committed by Henri Vasserman
parent c1cb70d64d
commit d91456aaf1
No known key found for this signature in database
GPG key ID: 2995FC0F58B1A986

View file

@@ -472,8 +472,8 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q4_1 * x = (const block_q4_1 *) vx; const block_q4_1 * x = (const block_q4_1 *) vx;
const dfloat d = x[ib].dm.x; const dfloat d = __low2half(x[ib].dm);
const dfloat m = x[ib].dm.y; const dfloat m = __high2half(x[ib].dm);
const int vui = x[ib].qs[iqs]; const int vui = x[ib].qs[iqs];
@@ -515,8 +515,8 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q5_1 * x = (const block_q5_1 *) vx; const block_q5_1 * x = (const block_q5_1 *) vx;
const dfloat d = x[ib].dm.x; const dfloat d = __low2half(x[ib].dm);
const dfloat m = x[ib].dm.y; const dfloat m = __high2half(x[ib].dm);
uint32_t qh; uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh)); memcpy(&qh, x[ib].qh, sizeof(qh));
@@ -568,8 +568,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
const uint8_t q = x[i].qs[32*n + l]; const uint8_t q = x[i].qs[32*n + l];
float * y = yy + i*QK_K + 128*n; float * y = yy + i*QK_K + 128*n;
float dall = x[i].dm.x; float dall = __low2half(x[i].dm);
float dmin = x[i].dm.y; float dmin = __high2half(x[i].dm);
y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
@@ -579,8 +579,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
const int il = tid%16; // 0...15 const int il = tid%16; // 0...15
const uint8_t q = x[i].qs[il] >> (2*is); const uint8_t q = x[i].qs[il] >> (2*is);
float * y = yy + i*QK_K + 16*is + il; float * y = yy + i*QK_K + 16*is + il;
float dall = x[i].dm.x; float dall = __low2half(x[i].dm);
float dmin = x[i].dm.y; float dmin = __high2half(x[i].dm);
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4); y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
#endif #endif
@@ -666,8 +666,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
float * y = yy + i*QK_K + 64*il + n*ir; float * y = yy + i*QK_K + 64*il + n*ir;
const float dall = x[i].dm.x; const float dall = __low2half(x[i].dm);
const float dmin = x[i].dm.y; const float dmin = __high2half(x[i].dm);
const uint8_t * q = x[i].qs + 32*il + n*ir; const uint8_t * q = x[i].qs + 32*il + n*ir;
@@ -705,8 +705,8 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
float * y = yy + i*QK_K + 64*il + 2*ir; float * y = yy + i*QK_K + 64*il + 2*ir;
const float dall = x[i].dm.x; const float dall = __low2half(x[i].dm);
const float dmin = x[i].dm.y; const float dmin = __high2half(x[i].dm);
const uint8_t * ql = x[i].qs + 32*il + 2*ir; const uint8_t * ql = x[i].qs + 32*il + 2*ir;
const uint8_t * qh = x[i].qh + 2*ir; const uint8_t * qh = x[i].qh + 2*ir;
@@ -818,8 +818,8 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
const float * y = yy + i * QK_K + y_offset; const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset; const uint8_t * q = x[i].qs + q_offset;
const float dall = x[i].dm.x; const float dall = __low2half(x[i].dm);
const float dmin = x[i].dm.y; const float dmin = __high2half(x[i].dm);
const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
aux[0] = a[0] & 0x0f0f0f0f; aux[0] = a[0] & 0x0f0f0f0f;
@@ -1039,8 +1039,8 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
const float * y1 = yy + i*QK_K + y_offset; const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128; const float * y2 = y1 + 128;
const float dall = x[i].dm.x; const float dall = __low2half(x[i].dm);
const float dmin = x[i].dm.y; const float dmin = __high2half(x[i].dm);
const uint16_t * a = (const uint16_t *)x[i].scales; const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1; aux[0] = a[im+0] & kmask1;
@@ -1172,8 +1172,8 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
const float * y1 = yy + i*QK_K + y_offset; const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128; const float * y2 = y1 + 128;
const float dall = x[i].dm.x; const float dall = __low2half(x[i].dm);
const float dmin = x[i].dm.y; const float dmin = __high2half(x[i].dm);
const uint16_t * a = (const uint16_t *)x[i].scales; const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1; aux[0] = a[im+0] & kmask1;