k_quants: WIP super-blocks with 64 weights
Q4_K working on CUDA. ~10% slower on GTX-1660, 16% slower on 4080.
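
Note (not part of the commit itself): under the new GGML_QKK_64 path a q4_K super-block holds only 64 weights, stored as two fp16 scale/min pairs plus 32 bytes of packed 4-bit quants. The sketch below is a hedged host-side reference for how one such block would be dequantized, mirroring the new QK_K != 256 branch of dequantize_block_q4_K in the diff; QK_K is assumed to be 64 and the helper name dequantize_q4_K_64_ref is illustrative only, it does not exist in ggml.

// Hedged sketch, not part of this commit: reference dequantization of one
// 64-weight block_q4_K super-block (GGML_QKK_64 layout). Compile as CUDA/C++.
#include <cuda_fp16.h>
#include <stdint.h>

#define QK_K 64  // assumed super-block size for the GGML_QKK_64 build

typedef struct {
    half    d[2*QK_K/32];  // d[0]/d[1]: scale/min for weights 0..31, d[2]/d[3]: for 32..63
    uint8_t qs[QK_K/2];    // 32 bytes of packed 4-bit quants
} block_q4_K;

static void dequantize_q4_K_64_ref(const block_q4_K * x, float * y) {
    const uint8_t * q = x->qs;
    const float d1 = __half2float(x->d[0]), m1 = __half2float(x->d[1]);
    const float d2 = __half2float(x->d[2]), m2 = __half2float(x->d[3]);
    for (int l = 0; l < 32; ++l) {
        y[l +  0] = d1 * (q[l] & 0xF) - m1;  // low nibbles  -> weights  0..31
        y[l + 32] = d2 * (q[l] >>  4) - m2;  // high nibbles -> weights 32..63
    }
}

With QK_K = 64 such a block is 4 halves + 32 bytes = 40 bytes per 64 weights, i.e. 5.0 bits per weight, which is what the new static_assert in the diff checks.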
parent c6c35366bf
commit 5aae4b8d4f
1 changed file with 65 additions and 26 deletions
ggml-cuda.cu (+65, -26)
@@ -142,13 +142,21 @@ typedef struct {
 } block_q3_K;
 //static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
 
+#ifdef GGML_QKK_64
+typedef struct {
+    half    d[2*QK_K/32];      // super-block scales/mins
+    uint8_t qs[QK_K/2];        // 4--bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2, "wrong q4_K block size/padding");
+#else
 typedef struct {
     half d;                    // super-block scale for quantized scales
     half dmin;                 // super-block scale for quantized mins
     uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
     uint8_t qs[QK_K/2];        // 4--bit quants
 } block_q4_K;
-//static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+#endif
 
 typedef struct {
     half d;                    // super-block scale for quantized scales
@@ -420,13 +428,7 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
 
     const int i = blockIdx.x;
 
-    //// assume 64 threads - this is very slightly better than the one below
-    //const int tid = threadIdx.x;
-    //const int il  = tid/16;
-    //const int ir  = tid%16;
-    //const int is  = 2*il;
-    //const int n   = 2;
-
+#if QK_K == 256
     // assume 32 threads
     const int tid = threadIdx.x;
     const int il  = tid/8;
@@ -450,6 +452,13 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
         y[l + 0] = d1 * (q[l] & 0xF) - m1;
         y[l +32] = d2 * (q[l] >> 4) - m2;
     }
+#else
+    const int tid = threadIdx.x;
+    const uint8_t * q = x[i].qs;
+    float * y = yy + i*QK_K;
+    y[tid+ 0] = (float)x[i].d[0] * (q[tid] & 0xF) - (float)x[i].d[1];
+    y[tid+32] = (float)x[i].d[2] * (q[tid] >> 4) - (float)x[i].d[3];
+#endif
 }
 
 static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
@@ -681,10 +690,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
 
 static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
     const int num_blocks_per_row = ncols / QK_K;
@@ -693,6 +698,13 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
+    const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+#if QK_K == 256
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int step = 8/K_QUANTS_PER_ITERATION;            // 8 or 4
 
     const int il  = tid/step;                              // 0...3
@@ -709,8 +721,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
-    const block_q4_K * x = (const block_q4_K *)vx + ib0;
-
     float tmp = 0; // partial sum for thread in warp
 
     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
@@ -739,6 +749,36 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
         tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
 
     }
+#else
+    const int step = tid * K_QUANTS_PER_ITERATION;
+
+    float tmp = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+        const uint8_t * q = x[i].qs + step;
+        const float   * y = yy + i*QK_K + step;
+        const half2   * d = (const half2 *)x[i].d;
+        float2 df1 = __half22float2(d[0]);
+        float2 df2 = __half22float2(d[1]);
+        //float2 sum = {0.f, 0.f};
+        //float smin = 0;
+        //for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+        //    sum.x += y[j+ 0] * (q[j+ 0] & 0xF) + y[j+16] * (q[j+16] & 0xF);
+        //    sum.y += y[j+32] * (q[j+ 0] >> 4) + y[j+48] * (q[j+16] >> 4);
+        //    smin  += df1.y * (y[j] + y[j+16]) + df2.y * (y[j+32] + y[j+48]);
+        //}
+        //tmp += df1.x * sum.x + df2.x * sum.y - smin;
+        float sum = 0.f;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            sum += y[j+ 0] * (df1.x * (q[j+ 0] & 0xF) - df1.y)
+                 + y[j+16] * (df1.x * (q[j+16] & 0xF) - df1.y)
+                 + y[j+32] * (df2.x * (q[j+ 0] >> 4) - df2.y)
+                 + y[j+48] * (df2.x * (q[j+16] >> 4) - df2.y);
+        }
+        tmp += sum;
+    }
+
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -904,14 +944,14 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
 
 #else
 
-    const int tid = threadIdx.x/4;   // 0...7
-    const int ix  = threadIdx.x%4;   // 0...3
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...7
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0...3
 
-    const int step = tid;
+    const int step = tid * K_QUANTS_PER_ITERATION;
 
     float tmp = 0; // partial sum for thread in warp
 
-    for (int i = ix; i < num_blocks_per_row; i += 4) {
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
 
         const float   * y  = yy + i * QK_K + step;
         const uint8_t * ql = x[i].ql + step;
@@ -920,14 +960,13 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
 
         const float d = x[i+0].d;
 
-        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[0] & 0x03) << 4)) - 32)
-                  + y[ 8] * s[0] * d * ((int8_t)((ql[ 8] & 0xF) | ((qh[8] & 0x03) << 4)) - 32)
-                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[0] & 0x0c) << 2)) - 32)
-                  + y[24] * s[1] * d * ((int8_t)((ql[24] & 0xF) | ((qh[8] & 0x0c) << 2)) - 32)
-                  + y[32] * s[2] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[0] & 0x30) >> 0)) - 32)
-                  + y[40] * s[2] * d * ((int8_t)((ql[ 8] >> 4) | ((qh[8] & 0x30) >> 0)) - 32)
-                  + y[48] * s[3] * d * ((int8_t)((ql[16] >> 4) | ((qh[0] & 0xc0) >> 2)) - 32)
-                  + y[56] * s[3] * d * ((int8_t)((ql[24] >> 4) | ((qh[8] & 0xc0) >> 2)) - 32);
+        float sum = 0;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
+                 + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
+                 + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
+                 + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
+        }
         tmp += sum;
 
     }
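
A second hedged reference, again not part of the commit: the new QK_K != 256 branch of dequantize_mul_mat_vec_q4_k accumulates terms of the form df1.x * (q & 0xF) - df1.y and df2.x * (q >> 4) - df2.y, weighted by y. Ignoring how the work is split across threads and K_QUANTS_PER_ITERATION (the launch configuration is not fully shown in this diff), the per-row quantity it is built around can be written as a plain scalar loop, reusing the QK_K = 64 definitions from the sketch above; vec_dot_q4_K_64_ref is an illustrative name only.

// Hedged sketch, not part of this commit: scalar view of the per-row
// dequantize-and-dot computation for the 64-weight block_q4_K layout.
static float vec_dot_q4_K_64_ref(int num_blocks_per_row, const block_q4_K * x, const float * yy) {
    float tmp = 0.f;  // corresponds to the kernel's reduced per-thread partial sums
    for (int i = 0; i < num_blocks_per_row; ++i) {
        const uint8_t * q = x[i].qs;
        const float   * y = yy + i*QK_K;
        const float d1 = __half2float(x[i].d[0]), m1 = __half2float(x[i].d[1]);
        const float d2 = __half2float(x[i].d[2]), m2 = __half2float(x[i].d[3]);
        for (int l = 0; l < 32; ++l) {
            tmp += y[l +  0] * (d1 * (q[l] & 0xF) - m1)   // low nibbles,  weights  0..31
                 + y[l + 32] * (d2 * (q[l] >>  4) - m2);  // high nibbles, weights 32..63
        }
    }
    return tmp;
}

In the kernel, df1 and df2 are the (scale, min) pairs d[0], d[1] and d[2], d[3] loaded through a half2, and step = tid * K_QUANTS_PER_ITERATION selects which slice of q and y a given thread touches.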