k_quants: WIP super-blocks with 64 weights

Q2_K working on CUDA. ~3% slower on GTX-1660,
10% slower on 4080.
This commit is contained in:
Iwan Kawrakow 2023-06-22 11:39:00 +03:00
parent 5aae4b8d4f
commit 41e46ec1c2

View file

@ -364,13 +364,14 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
const int i = blockIdx.x;
const block_q2_K * x = (const block_q2_K *) vx;
const int tid = threadIdx.x;
#if QK_K == 256
const int n = tid/32;
const int l = tid - 32*n;
const int is = 8*n + l/16;
const block_q2_K * x = (const block_q2_K *) vx;
const uint8_t q = x[i].qs[32*n + l];
float * y = yy + i*QK_K + 128*n;
@ -380,6 +381,16 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
#else
const int is = tid/16; // 0 or 1
const int il = tid%16; // 0...15
const uint8_t q = x[i].qs[il] >> (2*is);
float * y = yy + i*QK_K + 16*is + il;
float dall = x[i].d;
float dmin = x[i].dmin;
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
#endif
}
@ -550,6 +561,9 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
const block_q2_K * x = (const block_q2_K *)vx + ib0;
float tmp = 0; // partial sum for thread in warp
#if QK_K == 256
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
@ -563,8 +577,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
const int s_offset = 8*im;
const int y_offset = 128*im + l0;
float tmp = 0; // partial sum for thread in warp
uint32_t aux[4];
const uint8_t * d = (const uint8_t *)aux;
const uint8_t * m = (const uint8_t *)(aux + 2);
@ -600,6 +612,39 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
tmp += dall * sum1 - dmin * sum2;
}
#else
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
const int offset = tid * K_QUANTS_PER_ITERATION;
uint32_t uaux[2];
const uint8_t * d = (const uint8_t *)uaux;
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + offset;
const uint8_t * q = x[i].qs + offset;
const uint32_t * s = (const uint32_t *)x[i].scales;
uaux[0] = s[0] & 0x0f0f0f0f;
uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
const half2 * dh = (const half2 *)&x[i].d;
const float2 dall = __half22float2(dh[0]);
float sum1 = 0, sum2 = 0;
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
const uint8_t ql = q[l];
sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
+ y[l+16] * d[1] * ((ql >> 2) & 3)
+ y[l+32] * d[2] * ((ql >> 4) & 3)
+ y[l+48] * d[3] * ((ql >> 6) & 3);
sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
}
tmp += dall.x * sum1 - dall.y * sum2;
}
#endif
// sum up partial sums and write back result
__syncthreads();
@ -1351,7 +1396,11 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
#else
dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
#endif
}
static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {