k_quants: WIP super-blocks with 64 weights
Q2_K working on CUDA. ~3% slower on GTX-1660, 10% slower on 4080.
This commit is contained in:
parent
5aae4b8d4f
commit
41e46ec1c2
1 changed files with 53 additions and 4 deletions
57
ggml-cuda.cu
57
ggml-cuda.cu
|
@ -364,13 +364,14 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
|
|||
static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_q2_K * x = (const block_q2_K *) vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
#if QK_K == 256
|
||||
const int n = tid/32;
|
||||
const int l = tid - 32*n;
|
||||
const int is = 8*n + l/16;
|
||||
|
||||
const block_q2_K * x = (const block_q2_K *) vx;
|
||||
|
||||
const uint8_t q = x[i].qs[32*n + l];
|
||||
float * y = yy + i*QK_K + 128*n;
|
||||
|
||||
|
@ -380,6 +381,16 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
|
|||
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
|
||||
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
|
||||
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
|
||||
#else
|
||||
const int is = tid/16; // 0 or 1
|
||||
const int il = tid%16; // 0...15
|
||||
const uint8_t q = x[i].qs[il] >> (2*is);
|
||||
float * y = yy + i*QK_K + 16*is + il;
|
||||
float dall = x[i].d;
|
||||
float dmin = x[i].dmin;
|
||||
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
|
||||
y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
|
@ -550,6 +561,9 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
|
|||
|
||||
const block_q2_K * x = (const block_q2_K *)vx + ib0;
|
||||
|
||||
float tmp = 0; // partial sum for thread in warp
|
||||
|
||||
#if QK_K == 256
|
||||
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
|
||||
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
|
||||
|
||||
|
@ -563,8 +577,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
|
|||
const int s_offset = 8*im;
|
||||
const int y_offset = 128*im + l0;
|
||||
|
||||
float tmp = 0; // partial sum for thread in warp
|
||||
|
||||
uint32_t aux[4];
|
||||
const uint8_t * d = (const uint8_t *)aux;
|
||||
const uint8_t * m = (const uint8_t *)(aux + 2);
|
||||
|
@ -600,6 +612,39 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
|
|||
tmp += dall * sum1 - dmin * sum2;
|
||||
|
||||
}
|
||||
#else
|
||||
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
|
||||
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
|
||||
const int offset = tid * K_QUANTS_PER_ITERATION;
|
||||
|
||||
uint32_t uaux[2];
|
||||
const uint8_t * d = (const uint8_t *)uaux;
|
||||
|
||||
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
|
||||
|
||||
const float * y = yy + i * QK_K + offset;
|
||||
const uint8_t * q = x[i].qs + offset;
|
||||
const uint32_t * s = (const uint32_t *)x[i].scales;
|
||||
|
||||
uaux[0] = s[0] & 0x0f0f0f0f;
|
||||
uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
|
||||
|
||||
const half2 * dh = (const half2 *)&x[i].d;
|
||||
|
||||
const float2 dall = __half22float2(dh[0]);
|
||||
|
||||
float sum1 = 0, sum2 = 0;
|
||||
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
|
||||
const uint8_t ql = q[l];
|
||||
sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
|
||||
+ y[l+16] * d[1] * ((ql >> 2) & 3)
|
||||
+ y[l+32] * d[2] * ((ql >> 4) & 3)
|
||||
+ y[l+48] * d[3] * ((ql >> 6) & 3);
|
||||
sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
|
||||
}
|
||||
tmp += dall.x * sum1 - dall.y * sum2;
|
||||
}
|
||||
#endif
|
||||
|
||||
// sum up partial sums and write back result
|
||||
__syncthreads();
|
||||
|
@ -1351,7 +1396,11 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
|
|||
|
||||
static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
#if QK_K == 256
|
||||
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||
#else
|
||||
dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue