diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 486bc866d..ec66c884a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -161,14 +161,23 @@ typedef struct { static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); #endif +#ifdef GGML_QKK_64 typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits + half d[2*QK_K/32]; // super-block scales/mins uint8_t qh[QK_K/8]; // quants, high bit uint8_t qs[QK_K/2]; // quants, low 4 bits } block_q5_K; -//static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +static_assert(sizeof(block_q5_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#else +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#endif typedef struct { uint8_t ql[QK_K/2]; // quants, lower 4 bits @@ -445,6 +454,7 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { } +#if QK_K == 256 static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { if (j < 4) { d = q[j] & 63; m = q[j + 4] & 63; @@ -453,6 +463,7 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); } } +#endif static __global__ void dequantize_block_q4_K(const void * vx, float * yy) { const block_q4_K * x = (const block_q4_K *) vx; @@ -497,6 +508,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { const int i = blockIdx.x; +#if QK_K == 256 // assume 64 threads - this is very slightly better than the one below const int tid = threadIdx.x; const int il = tid/16; // il is in 0...3 @@ -523,6 +535,16 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { hm <<= 1; y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; +#else + const int tid = threadIdx.x; + const uint8_t q = x[i].qs[tid]; + const int im = tid/8; // 0...3 + const int in = tid%8; // 0...7 + const uint8_t h = x[i].qh[in] >> im; + float * y = yy + i*QK_K + tid; + y[ 0] = (float)x[i].d[0] * ((q & 0xF) + ((h >> 0) & 1 ? 16 : 0)) - (float)x[i].d[1]; + y[32] = (float)x[i].d[2] * ((q >> 4) + ((h >> 4) & 1 ? 16 : 0)) - (float)x[i].d[3]; +#endif } static __global__ void dequantize_block_q6_K(const void * vx, float * yy) { @@ -855,14 +877,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float const half2 * d = (const half2 *)x[i].d; float2 df1 = __half22float2(d[0]); float2 df2 = __half22float2(d[1]); - //float2 sum = {0.f, 0.f}; - //float smin = 0; - //for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { - // sum.x += y[j+ 0] * (q[j+ 0] & 0xF) + y[j+16] * (q[j+16] & 0xF); - // sum.y += y[j+32] * (q[j+ 0] >> 4) + y[j+48] * (q[j+16] >> 4); - // smin += df1.y * (y[j] + y[j+16]) + df2.y * (y[j+32] + y[j+48]); - //} - //tmp += df1.x * sum.x + df2.x * sum.y - smin; float sum = 0.f; for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { sum += y[j+ 0] * (df1.x * (q[j+ 0] & 0xF) - df1.y) @@ -889,15 +903,19 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - //const int row = blockIdx.x*blockDim.y + threadIdx.y; const int row = blockIdx.x; const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; + const block_q5_K * x = (const block_q5_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + const int tid = threadIdx.x/2; // 0...15 const int ix = threadIdx.x%2; @@ -918,10 +936,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float uint16_t aux[4]; const uint8_t * sc = (const uint8_t *)aux; - const block_q5_K * x = (const block_q5_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += 2) { const uint8_t * ql1 = x[i].qs + q_offset; @@ -954,9 +968,33 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; } tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; - } +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); + const int step = tid * K_QUANTS_PER_ITERATION; + const int im = step/8; + const int in = step%8; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const uint8_t * q = x[i].qs + step; + const float * y = yy + i*QK_K + step; + const half2 * d = (const half2 *)x[i].d; + float2 df1 = __half22float2(d[0]); + float2 df2 = __half22float2(d[1]); + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + const uint8_t h = x[i].qh[in+j] >> im; + sum += y[j+ 0] * (df1.x * ((q[j+ 0] & 0xF) + (((h >> 0) & 1) << 4)) - df1.y) + + y[j+16] * (df1.x * ((q[j+16] & 0xF) + (((h >> 2) & 1) << 4)) - df1.y) + + y[j+32] * (df2.x * ((q[j+ 0] >> 4) + (((h >> 4) & 1) << 4)) - df2.y) + + y[j+48] * (df2.x * ((q[j+16] >> 4) + (((h >> 6) & 1) << 4)) - df2.y); + } + tmp += sum; + } +#endif + // sum up partial sums and write back result __syncthreads(); #pragma unroll @@ -964,7 +1002,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); } - if (tid == 0) { + if (threadIdx.x == 0) { dst[row] = tmp; } } @@ -1469,7 +1507,11 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; +#if QK_K == 256 dequantize_block_q5_K<<>>(vx, y); +#else + dequantize_block_q5_K<<>>(vx, y); +#endif } static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {