k_quants: WIP super-blocks with 64 weights

Q5_K working on CUDA, and with this CUDA is done.
This commit is contained in:
Iwan Kawrakow 2023-06-22 13:31:54 +03:00
parent 460dd841b1
commit 3bd9ae79d8

View file

@ -161,14 +161,23 @@ typedef struct {
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
#endif
#ifdef GGML_QKK_64
typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
half d[2*QK_K/32]; // super-block scales/mins
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
//static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
static_assert(sizeof(block_q5_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
#else
typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
#endif
typedef struct {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
@ -445,6 +454,7 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
}
#if QK_K == 256
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
if (j < 4) {
d = q[j] & 63; m = q[j + 4] & 63;
@ -453,6 +463,7 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
}
#endif
static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
const block_q4_K * x = (const block_q4_K *) vx;
@ -497,6 +508,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
const int i = blockIdx.x;
#if QK_K == 256
// assume 64 threads - this is very slightly better than the one below
const int tid = threadIdx.x;
const int il = tid/16; // il is in 0...3
@ -523,6 +535,16 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
hm <<= 1;
y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
#else
const int tid = threadIdx.x;
const uint8_t q = x[i].qs[tid];
const int im = tid/8; // 0...3
const int in = tid%8; // 0...7
const uint8_t h = x[i].qh[in] >> im;
float * y = yy + i*QK_K + tid;
y[ 0] = (float)x[i].d[0] * ((q & 0xF) + ((h >> 0) & 1 ? 16 : 0)) - (float)x[i].d[1];
y[32] = (float)x[i].d[2] * ((q >> 4) + ((h >> 4) & 1 ? 16 : 0)) - (float)x[i].d[3];
#endif
}
static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
@ -855,14 +877,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
const half2 * d = (const half2 *)x[i].d;
float2 df1 = __half22float2(d[0]);
float2 df2 = __half22float2(d[1]);
//float2 sum = {0.f, 0.f};
//float smin = 0;
//for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
// sum.x += y[j+ 0] * (q[j+ 0] & 0xF) + y[j+16] * (q[j+16] & 0xF);
// sum.y += y[j+32] * (q[j+ 0] >> 4) + y[j+48] * (q[j+16] >> 4);
// smin += df1.y * (y[j] + y[j+16]) + df2.y * (y[j+32] + y[j+48]);
//}
//tmp += df1.x * sum.x + df2.x * sum.y - smin;
float sum = 0.f;
for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
sum += y[j+ 0] * (df1.x * (q[j+ 0] & 0xF) - df1.y)
@ -889,15 +903,19 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
//const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int row = blockIdx.x;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const block_q5_K * x = (const block_q5_K *)vx + ib0;
float tmp = 0; // partial sum for thread in warp
#if QK_K == 256
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
const int tid = threadIdx.x/2; // 0...15
const int ix = threadIdx.x%2;
@ -918,10 +936,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;
const block_q5_K * x = (const block_q5_K *)vx + ib0;
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += 2) {
const uint8_t * ql1 = x[i].qs + q_offset;
@ -954,9 +968,33 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
+ (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
}
tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
}
#else
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
const int step = tid * K_QUANTS_PER_ITERATION;
const int im = step/8;
const int in = step%8;
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
const uint8_t * q = x[i].qs + step;
const float * y = yy + i*QK_K + step;
const half2 * d = (const half2 *)x[i].d;
float2 df1 = __half22float2(d[0]);
float2 df2 = __half22float2(d[1]);
float sum = 0.f;
for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
const uint8_t h = x[i].qh[in+j] >> im;
sum += y[j+ 0] * (df1.x * ((q[j+ 0] & 0xF) + (((h >> 0) & 1) << 4)) - df1.y)
+ y[j+16] * (df1.x * ((q[j+16] & 0xF) + (((h >> 2) & 1) << 4)) - df1.y)
+ y[j+32] * (df2.x * ((q[j+ 0] >> 4) + (((h >> 4) & 1) << 4)) - df2.y)
+ y[j+48] * (df2.x * ((q[j+16] >> 4) + (((h >> 6) & 1) << 4)) - df2.y);
}
tmp += sum;
}
#endif
// sum up partial sums and write back result
__syncthreads();
#pragma unroll
@ -964,7 +1002,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
if (tid == 0) {
if (threadIdx.x == 0) {
dst[row] = tmp;
}
}
@ -1469,7 +1507,11 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu
static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
#else
dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
#endif
}
static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {