k_quants: WIP super-blocks with 64 weights

Q3_K working on CUDA.
This commit is contained in:
Iwan Kawrakow 2023-06-22 12:51:18 +03:00
parent 41e46ec1c2
commit 460dd841b1

View file

@ -125,7 +125,6 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
#define K_SCALE_SIZE 12
#endif
typedef struct {
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
uint8_t qs[QK_K/4]; // quants
@ -135,12 +134,16 @@ typedef struct {
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
typedef struct {
uint8_t hmask[QK_K/8];
uint8_t qs[QK_K/4]; // nibbles / quants
uint8_t scales[3*QK_K/64];
half d;
uint8_t hmask[QK_K/8]; // quants - high bit
uint8_t qs[QK_K/4]; // quants - low 2 bits
#ifdef GGML_QKK_64
int8_t scales[K_SCALE_SIZE]; // scales, quantized with 8 bits
#else
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
#endif
half d; // super-block scale
} block_q3_K;
//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
#ifdef GGML_QKK_64
typedef struct {
@ -396,16 +399,17 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
int r = threadIdx.x/4;
int i = blockIdx.x;
int tid = r/2;
int is0 = r%2;
int l0 = 16*is0 + 4*(threadIdx.x%4);
int n = tid / 4;
int j = tid - 4*n;
const int i = blockIdx.x;
const block_q3_K * x = (const block_q3_K *) vx;
#if QK_K == 256
const int r = threadIdx.x/4;
const int tid = r/2;
const int is0 = r%2;
const int l0 = 16*is0 + 4*(threadIdx.x%4);
const int n = tid / 4;
const int j = tid - 4*n;
uint8_t m = 1 << (4*n + j);
int is = 8*n + 2*j + is0;
int shift = 2*j;
@ -422,6 +426,22 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
const uint8_t * hm = x[i].hmask;
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
#else
const int tid = threadIdx.x;
const int is = tid/16; // 0 or 1
const int il = tid%16; // 0...15
const int im = il/8; // 0...1
const int in = il%8; // 0...7
float * y = yy + i*QK_K + 16*is + il;
const uint8_t q = x[i].qs[il] >> (2*is);
const uint8_t h = x[i].hmask[in] >> (2*is + im);
const float d = (float)x[i].d;
y[ 0] = d * x[i].scales[is+0] * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
y[32] = d * x[i].scales[is+2] * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
#endif
}
@ -660,9 +680,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
const uint16_t kmask1 = 0x0303;
const uint16_t kmask2 = 0x0f0f;
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
@ -671,6 +688,13 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
const block_q3_K * x = (const block_q3_K *)vx + ib0;
float tmp = 0; // partial sum for thread in warp
#if QK_K == 256
const uint16_t kmask1 = 0x0303;
const uint16_t kmask2 = 0x0f0f;
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
@ -690,8 +714,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
const uint16_t s_shift = 4*im;
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + y_offset;
@ -720,6 +742,34 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
tmp += d * sum;
}
#else
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14
const int in = offset/8; // 0 or 1
const int im = offset%8; // 0...7
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + offset;
const uint8_t * q = x[i].qs + offset;
const int8_t * s = x[i].scales;
const float dall = (float)x[i].d;
float sum = 0;
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
const uint8_t hl = x[i].hmask[im+l] >> in;
const uint8_t ql = q[l];
sum += y[l+ 0] * dall * s[0] * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
+ y[l+16] * dall * s[1] * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
+ y[l+32] * dall * s[2] * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
+ y[l+48] * dall * s[3] * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
}
tmp += sum;
}
#endif
// sum up partial sums and write back result
__syncthreads();
@ -728,7 +778,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
if (tid == 0) {
if (threadIdx.x == 0) {
dst[row] = tmp;
}
}
@ -1405,7 +1455,11 @@ static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cu
static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
#else
dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
#endif
}
static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {