k_quants: WIP super-blocks with 64 weights
Q3_K working on CUDA.
This commit is contained in:
parent
41e46ec1c2
commit
460dd841b1
1 changed file with 74 additions and 20 deletions
94
ggml-cuda.cu
94
ggml-cuda.cu
|
@ -125,7 +125,6 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
|
||||||
#define K_SCALE_SIZE 12
|
#define K_SCALE_SIZE 12
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
||||||
uint8_t qs[QK_K/4]; // quants
|
uint8_t qs[QK_K/4]; // quants
|
||||||
|
@ -135,12 +134,16 @@ typedef struct {
|
||||||
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
|
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
uint8_t hmask[QK_K/8];
|
uint8_t hmask[QK_K/8]; // quants - high bit
|
||||||
uint8_t qs[QK_K/4]; // nibbles / quants
|
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
||||||
uint8_t scales[3*QK_K/64];
|
#ifdef GGML_QKK_64
|
||||||
half d;
|
int8_t scales[K_SCALE_SIZE]; // scales, quantized with 8 bits
|
||||||
|
#else
|
||||||
|
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
|
||||||
|
#endif
|
||||||
|
half d; // super-block scale
|
||||||
} block_q3_K;
|
} block_q3_K;
|
||||||
//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
|
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
|
||||||
|
|
||||||
#ifdef GGML_QKK_64
|
#ifdef GGML_QKK_64
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
@ -396,16 +399,17 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
|
||||||
|
|
||||||
static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
|
static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
|
||||||
|
|
||||||
int r = threadIdx.x/4;
|
const int i = blockIdx.x;
|
||||||
int i = blockIdx.x;
|
|
||||||
int tid = r/2;
|
|
||||||
int is0 = r%2;
|
|
||||||
int l0 = 16*is0 + 4*(threadIdx.x%4);
|
|
||||||
int n = tid / 4;
|
|
||||||
int j = tid - 4*n;
|
|
||||||
|
|
||||||
const block_q3_K * x = (const block_q3_K *) vx;
|
const block_q3_K * x = (const block_q3_K *) vx;
|
||||||
|
|
||||||
|
#if QK_K == 256
|
||||||
|
const int r = threadIdx.x/4;
|
||||||
|
const int tid = r/2;
|
||||||
|
const int is0 = r%2;
|
||||||
|
const int l0 = 16*is0 + 4*(threadIdx.x%4);
|
||||||
|
const int n = tid / 4;
|
||||||
|
const int j = tid - 4*n;
|
||||||
|
|
||||||
uint8_t m = 1 << (4*n + j);
|
uint8_t m = 1 << (4*n + j);
|
||||||
int is = 8*n + 2*j + is0;
|
int is = 8*n + 2*j + is0;
|
||||||
int shift = 2*j;
|
int shift = 2*j;
|
||||||
|
@ -422,6 +426,22 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
|
||||||
const uint8_t * hm = x[i].hmask;
|
const uint8_t * hm = x[i].hmask;
|
||||||
|
|
||||||
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
|
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
|
||||||
|
#else
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
const int is = tid/16; // 0 or 1
|
||||||
|
const int il = tid%16; // 0...15
|
||||||
|
const int im = il/8; // 0...1
|
||||||
|
const int in = il%8; // 0...7
|
||||||
|
|
||||||
|
float * y = yy + i*QK_K + 16*is + il;
|
||||||
|
|
||||||
|
const uint8_t q = x[i].qs[il] >> (2*is);
|
||||||
|
const uint8_t h = x[i].hmask[in] >> (2*is + im);
|
||||||
|
const float d = (float)x[i].d;
|
||||||
|
|
||||||
|
y[ 0] = d * x[i].scales[is+0] * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
|
||||||
|
y[32] = d * x[i].scales[is+2] * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
|
||||||
|
#endif
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -660,9 +680,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
|
||||||
|
|
||||||
static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
|
static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
|
||||||
|
|
||||||
const uint16_t kmask1 = 0x0303;
|
|
||||||
const uint16_t kmask2 = 0x0f0f;
|
|
||||||
|
|
||||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||||
if (row > nrows) return;
|
if (row > nrows) return;
|
||||||
|
|
||||||
|
@ -671,6 +688,13 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
|
||||||
|
|
||||||
const block_q3_K * x = (const block_q3_K *)vx + ib0;
|
const block_q3_K * x = (const block_q3_K *)vx + ib0;
|
||||||
|
|
||||||
|
float tmp = 0; // partial sum for thread in warp
|
||||||
|
|
||||||
|
#if QK_K == 256
|
||||||
|
|
||||||
|
const uint16_t kmask1 = 0x0303;
|
||||||
|
const uint16_t kmask2 = 0x0f0f;
|
||||||
|
|
||||||
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
|
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
|
||||||
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
|
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
|
||||||
|
|
||||||
|
@ -690,8 +714,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
|
||||||
|
|
||||||
const uint16_t s_shift = 4*im;
|
const uint16_t s_shift = 4*im;
|
||||||
|
|
||||||
float tmp = 0; // partial sum for thread in warp
|
|
||||||
|
|
||||||
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
|
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
|
||||||
|
|
||||||
const float * y = yy + i * QK_K + y_offset;
|
const float * y = yy + i * QK_K + y_offset;
|
||||||
|
@ -720,6 +742,34 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
|
||||||
tmp += d * sum;
|
tmp += d * sum;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
|
||||||
|
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
|
||||||
|
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
|
||||||
|
const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14
|
||||||
|
const int in = offset/8; // 0 or 1
|
||||||
|
const int im = offset%8; // 0...7
|
||||||
|
|
||||||
|
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
|
||||||
|
|
||||||
|
const float * y = yy + i * QK_K + offset;
|
||||||
|
const uint8_t * q = x[i].qs + offset;
|
||||||
|
const int8_t * s = x[i].scales;
|
||||||
|
|
||||||
|
const float dall = (float)x[i].d;
|
||||||
|
|
||||||
|
float sum = 0;
|
||||||
|
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
|
||||||
|
const uint8_t hl = x[i].hmask[im+l] >> in;
|
||||||
|
const uint8_t ql = q[l];
|
||||||
|
sum += y[l+ 0] * dall * s[0] * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
|
||||||
|
+ y[l+16] * dall * s[1] * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
|
||||||
|
+ y[l+32] * dall * s[2] * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
|
||||||
|
+ y[l+48] * dall * s[3] * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
|
||||||
|
}
|
||||||
|
tmp += sum;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
@ -728,7 +778,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tid == 0) {
|
if (threadIdx.x == 0) {
|
||||||
dst[row] = tmp;
|
dst[row] = tmp;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1405,7 +1455,11 @@ static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cu
|
||||||
|
|
||||||
static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||||
const int nb = k / QK_K;
|
const int nb = k / QK_K;
|
||||||
|
#if QK_K == 256
|
||||||
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
|
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||||
|
#else
|
||||||
|
dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue