k_quants: WIP super-blocks with 64 weights
Q6_K working on CUDA. Cannot make it run quite as gast as with super-blocks with 256 weigths: 8% slower on 4080, 20% slower on the 1660 (but there we fit 1 less layer on the GPU because pf the larger model size), so some fraction of these 20% is due to that,
This commit is contained in:
parent
bcf8c5c384
commit
c6c35366bf
2 changed files with 75 additions and 11 deletions
|
@ -226,6 +226,14 @@ if (LLAMA_BLAS)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
if (LLAMA_K_QUANTS)
|
||||
set(GGML_SOURCES_EXTRA ${GGML_SOURCES_EXTRA} k_quants.c k_quants.h)
|
||||
add_compile_definitions(GGML_USE_K_QUANTS)
|
||||
if (LLAMA_QKK_64)
|
||||
add_compile_definitions(GGML_QKK_64)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (LLAMA_CUBLAS)
|
||||
cmake_minimum_required(VERSION 3.17)
|
||||
|
||||
|
@ -290,14 +298,6 @@ if (LLAMA_METAL)
|
|||
)
|
||||
endif()
|
||||
|
||||
if (LLAMA_K_QUANTS)
|
||||
set(GGML_SOURCES_EXTRA ${GGML_SOURCES_EXTRA} k_quants.c k_quants.h)
|
||||
add_compile_definitions(GGML_USE_K_QUANTS)
|
||||
if (LLAMA_QKK_64)
|
||||
add_compile_definitions(GGML_QKK_64)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (LLAMA_CLBLAST)
|
||||
find_package(CLBlast)
|
||||
if (CLBlast_FOUND)
|
||||
|
|
70
ggml-cuda.cu
70
ggml-cuda.cu
|
@ -117,7 +117,14 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
|
|||
|
||||
//================================= k-quants
|
||||
|
||||
#ifdef GGML_QKK_64
|
||||
#define QK_K 64
|
||||
#define K_SCALE_SIZE 4
|
||||
#else
|
||||
#define QK_K 256
|
||||
#define K_SCALE_SIZE 12
|
||||
#endif
|
||||
|
||||
|
||||
typedef struct {
|
||||
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
||||
|
@ -133,7 +140,7 @@ typedef struct {
|
|||
uint8_t scales[3*QK_K/64];
|
||||
half d;
|
||||
} block_q3_K;
|
||||
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
|
||||
//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
|
||||
|
||||
typedef struct {
|
||||
half d; // super-block scale for quantized scales
|
||||
|
@ -141,7 +148,7 @@ typedef struct {
|
|||
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
|
||||
uint8_t qs[QK_K/2]; // 4--bit quants
|
||||
} block_q4_K;
|
||||
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
|
||||
//static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
|
||||
|
||||
typedef struct {
|
||||
half d; // super-block scale for quantized scales
|
||||
|
@ -150,7 +157,7 @@ typedef struct {
|
|||
uint8_t qh[QK_K/8]; // quants, high bit
|
||||
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
||||
} block_q5_K;
|
||||
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
|
||||
//static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
|
||||
|
||||
typedef struct {
|
||||
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
||||
|
@ -482,6 +489,7 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
|
|||
const block_q6_K * x = (const block_q6_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
#if QK_K == 256
|
||||
|
||||
// assume 64 threads - this is very slightly better than the one below
|
||||
const int tid = threadIdx.x;
|
||||
|
@ -501,6 +509,24 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
|
|||
y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
|
||||
y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
|
||||
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
|
||||
#else
|
||||
|
||||
// assume 32 threads
|
||||
const int tid = threadIdx.x;
|
||||
const int ip = tid/16; // 0 or 1
|
||||
const int il = tid - 16*ip; // 0...15
|
||||
|
||||
float * y = yy + i*QK_K + 16*ip + il;
|
||||
|
||||
const float d = x[i].d;
|
||||
|
||||
const uint8_t ql = x[i].ql[16*ip + il];
|
||||
const uint8_t qh = x[i].qh[il] >> (2*ip);
|
||||
const int8_t * sc = x[i].scales;
|
||||
|
||||
y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
|
||||
y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32);
|
||||
#endif
|
||||
}
|
||||
|
||||
static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
|
||||
|
@ -820,6 +846,8 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
|
|||
|
||||
const block_q6_K * x = (const block_q6_K *)vx + ib0;
|
||||
|
||||
#if QK_K == 256
|
||||
|
||||
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
|
||||
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
|
||||
|
||||
|
@ -874,6 +902,38 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
|
|||
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
const int tid = threadIdx.x/4; // 0...7
|
||||
const int ix = threadIdx.x%4; // 0...3
|
||||
|
||||
const int step = tid;
|
||||
|
||||
float tmp = 0; // partial sum for thread in warp
|
||||
|
||||
for (int i = ix; i < num_blocks_per_row; i += 4) {
|
||||
|
||||
const float * y = yy + i * QK_K + step;
|
||||
const uint8_t * ql = x[i].ql + step;
|
||||
const uint8_t * qh = x[i].qh + step;
|
||||
const int8_t * s = x[i].scales;
|
||||
|
||||
const float d = x[i+0].d;
|
||||
|
||||
float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[0] & 0x03) << 4)) - 32)
|
||||
+ y[ 8] * s[0] * d * ((int8_t)((ql[ 8] & 0xF) | ((qh[8] & 0x03) << 4)) - 32)
|
||||
+ y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[0] & 0x0c) << 2)) - 32)
|
||||
+ y[24] * s[1] * d * ((int8_t)((ql[24] & 0xF) | ((qh[8] & 0x0c) << 2)) - 32)
|
||||
+ y[32] * s[2] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[0] & 0x30) >> 0)) - 32)
|
||||
+ y[40] * s[2] * d * ((int8_t)((ql[ 8] >> 4) | ((qh[8] & 0x30) >> 0)) - 32)
|
||||
+ y[48] * s[3] * d * ((int8_t)((ql[16] >> 4) | ((qh[0] & 0xc0) >> 2)) - 32)
|
||||
+ y[56] * s[3] * d * ((int8_t)((ql[24] >> 4) | ((qh[8] & 0xc0) >> 2)) - 32);
|
||||
tmp += sum;
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// sum up partial sums and write back result
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
|
@ -1272,7 +1332,11 @@ static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cu
|
|||
|
||||
static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
#if QK_K == 256
|
||||
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||
#else
|
||||
dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue