k_quants: WIP super-blocks with 64 weights

Q4_K working on CUDA. ~10% slower on GTX 1660,
~16% slower on RTX 4080.
Iwan Kawrakow 2023-06-22 09:40:33 +03:00
parent c6c35366bf
commit 5aae4b8d4f

@@ -142,13 +142,21 @@ typedef struct {
} block_q3_K;
//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
#ifdef GGML_QKK_64
typedef struct {
half d[2*QK_K/32]; // super-block scales/mins
uint8_t qs[QK_K/2]; // 4-bit quants
} block_q4_K;
static_assert(sizeof(block_q4_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2, "wrong q4_K block size/padding");
#else
typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4-bit quants
} block_q4_K;
//static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
#endif
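
As a quick size check of the two layouts above, a minimal host-side sketch (not part of the commit; assumes sizeof(ggml_fp16_t) == 2):

// QK_K == 64 : 2*64/32 fp16 scales/mins + 64/2 packed quant bytes = 8 + 32 = 40 bytes   -> 5.0 bits/weight
// QK_K == 256: 2 fp16 values + 3*256/64 scale bytes + 256/2 quant bytes = 4 + 12 + 128 = 144 bytes -> 4.5 bits/weight
static_assert(2*64/32*2 + 64/2 == 40, "QK_K=64 q4_K super-block is 40 bytes");
static_assert(2*2 + 3*256/64 + 256/2 == 144, "QK_K=256 q4_K super-block is 144 bytes");

In other words, the 64-weight variant spends an extra half bit per weight on the four fp16 scales/mins.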
typedef struct {
half d; // super-block scale for quantized scales
@@ -420,13 +428,7 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
const int i = blockIdx.x;
//// assume 64 threads - this is very slightly better than the one below
//const int tid = threadIdx.x;
//const int il = tid/16;
//const int ir = tid%16;
//const int is = 2*il;
//const int n = 2;
#if QK_K == 256
// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
@@ -450,6 +452,13 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
y[l + 0] = d1 * (q[l] & 0xF) - m1;
y[l +32] = d2 * (q[l] >> 4) - m2;
}
#else
const int tid = threadIdx.x;
const uint8_t * q = x[i].qs;
float * y = yy + i*QK_K;
y[tid+ 0] = (float)x[i].d[0] * (q[tid] & 0xF) - (float)x[i].d[1];
y[tid+32] = (float)x[i].d[2] * (q[tid] >> 4) - (float)x[i].d[3];
#endif
}
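
Read sequentially, the new QK_K == 64 branch of this kernel amounts to the following per-block reference (a host-side sketch, not part of the commit; the helper name is made up and the fp16 scales/mins are taken as already-converted floats):

#include <stdint.h>

// d4[0]/d4[1]: scale/min for the low nibbles (weights 0..31),
// d4[2]/d4[3]: scale/min for the high nibbles (weights 32..63).
static void dequantize_q4_K_64_ref(const float d4[4], const uint8_t qs[32], float y[64]) {
    for (int l = 0; l < 32; ++l) {
        y[l +  0] = d4[0] * (qs[l] & 0xF) - d4[1];
        y[l + 32] = d4[2] * (qs[l] >>  4) - d4[3];
    }
}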
static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
@@ -681,10 +690,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
@@ -693,6 +698,13 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
const block_q4_K * x = (const block_q4_K *)vx + ib0;
#if QK_K == 256
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
const int il = tid/step; // 0...3
@@ -709,8 +721,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;
const block_q4_K * x = (const block_q4_K *)vx + ib0;
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
@@ -739,6 +749,36 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
}
#else
const int step = tid * K_QUANTS_PER_ITERATION;
float tmp = 0;
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
const uint8_t * q = x[i].qs + step;
const float * y = yy + i*QK_K + step;
const half2 * d = (const half2 *)x[i].d;
float2 df1 = __half22float2(d[0]);
float2 df2 = __half22float2(d[1]);
//float2 sum = {0.f, 0.f};
//float smin = 0;
//for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
// sum.x += y[j+ 0] * (q[j+ 0] & 0xF) + y[j+16] * (q[j+16] & 0xF);
// sum.y += y[j+32] * (q[j+ 0] >> 4) + y[j+48] * (q[j+16] >> 4);
// smin += df1.y * (y[j] + y[j+16]) + df2.y * (y[j+32] + y[j+48]);
//}
//tmp += df1.x * sum.x + df2.x * sum.y - smin;
float sum = 0.f;
for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
sum += y[j+ 0] * (df1.x * (q[j+ 0] & 0xF) - df1.y)
+ y[j+16] * (df1.x * (q[j+16] & 0xF) - df1.y)
+ y[j+32] * (df2.x * (q[j+ 0] >> 4) - df2.y)
+ y[j+48] * (df2.x * (q[j+16] >> 4) - df2.y);
}
tmp += sum;
}
#endif
// sum up partial sums and write back result
__syncthreads();
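
Per 64-weight block, the new #else branch of this kernel accumulates the same quantity as the scalar reference below before the partial sums are reduced across the warp (a sketch, not part of the commit; the name is made up and the fp16 values are taken as floats):

#include <stdint.h>

// Contribution of one 64-weight q4_K super-block to the dot product with activations y[0..63].
static float vec_dot_q4_K_64_ref(const float d4[4], const uint8_t qs[32], const float * y) {
    float sum = 0.0f;
    for (int l = 0; l < 32; ++l) {
        sum += y[l +  0] * (d4[0] * (qs[l] & 0xF) - d4[1])   // low nibbles  -> weights  0..31
             + y[l + 32] * (d4[2] * (qs[l] >>  4) - d4[3]);  // high nibbles -> weights 32..63
    }
    return sum;
}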
@@ -904,14 +944,14 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
#else
const int tid = threadIdx.x/4; // 0...7
const int ix = threadIdx.x%4; // 0...3
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3
const int step = tid;
const int step = tid * K_QUANTS_PER_ITERATION;
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += 4) {
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + step;
const uint8_t * ql = x[i].ql + step;
@@ -920,14 +960,13 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
const float d = x[i+0].d;
float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[0] & 0x03) << 4)) - 32)
+ y[ 8] * s[0] * d * ((int8_t)((ql[ 8] & 0xF) | ((qh[8] & 0x03) << 4)) - 32)
+ y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[0] & 0x0c) << 2)) - 32)
+ y[24] * s[1] * d * ((int8_t)((ql[24] & 0xF) | ((qh[8] & 0x0c) << 2)) - 32)
+ y[32] * s[2] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[0] & 0x30) >> 0)) - 32)
+ y[40] * s[2] * d * ((int8_t)((ql[ 8] >> 4) | ((qh[8] & 0x30) >> 0)) - 32)
+ y[48] * s[3] * d * ((int8_t)((ql[16] >> 4) | ((qh[0] & 0xc0) >> 2)) - 32)
+ y[56] * s[3] * d * ((int8_t)((ql[24] >> 4) | ((qh[8] & 0xc0) >> 2)) - 32);
float sum = 0;
for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
+ y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
+ y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
+ y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
}
tmp += sum;
}
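
For QK_K == 64, the reworked q6_K loop computes, per super-block, the same contribution as this scalar reference (a sketch, not part of the commit; names and parameter types are illustrative):

#include <stdint.h>

// One 64-weight q6_K super-block: ql packs the low 4 bits of all 64 weights (two per byte),
// qh packs the upper 2 bits (four per byte), s[0..3] are the per-16-weight scales, d the block scale.
static float vec_dot_q6_K_64_ref(float d, const int8_t s[4],
                                 const uint8_t ql[32], const uint8_t qh[16],
                                 const float * y /* 64 values */) {
    float sum = 0.0f;
    for (int j = 0; j < 16; ++j) {
        sum += y[j +  0] * s[0] * d * ((int8_t)((ql[j +  0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
             + y[j + 16] * s[1] * d * ((int8_t)((ql[j + 16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
             + y[j + 32] * s[2] * d * ((int8_t)((ql[j +  0] >>  4) | ((qh[j] & 0x30) >> 0)) - 32)
             + y[j + 48] * s[3] * d * ((int8_t)((ql[j + 16] >>  4) | ((qh[j] & 0xc0) >> 2)) - 32);
    }
    return sum;
}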