diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 2e86226fb..a5fe7e572 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -364,13 +364,14 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
 static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
 
     const int i   = blockIdx.x;
+    const block_q2_K * x = (const block_q2_K *) vx;
+
     const int tid = threadIdx.x;
+#if QK_K == 256
     const int n   = tid/32;
     const int l   = tid - 32*n;
     const int is  = 8*n + l/16;
 
-    const block_q2_K * x = (const block_q2_K *) vx;
-
     const uint8_t q = x[i].qs[32*n + l];
 
     float * y = yy + i*QK_K + 128*n;
@@ -380,6 +381,16 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
     y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
     y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
     y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+#else
+    const int is = tid/16;  // 0 or 1
+    const int il = tid%16;  // 0...15
+    const uint8_t q = x[i].qs[il] >> (2*is);
+    float * y = yy + i*QK_K + 16*is + il;
+    float dall = x[i].d;
+    float dmin = x[i].dmin;
+    y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+    y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
+#endif
 
 }
 
@@ -550,6 +561,9 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
 
     const block_q2_K * x = (const block_q2_K *)vx + ib0;
 
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
@@ -563,8 +577,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
     const int s_offset = 8*im;
     const int y_offset = 128*im + l0;
 
-    float tmp = 0; // partial sum for thread in warp
-
     uint32_t aux[4];
     const uint8_t * d = (const uint8_t *)aux;
     const uint8_t * m = (const uint8_t *)(aux + 2);
@@ -600,6 +612,39 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
         tmp += dall * sum1 - dmin * sum2;
 
     }
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15 or 0...7
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0....1 or 0...3
+    const int offset = tid * K_QUANTS_PER_ITERATION;
+
+    uint32_t uaux[2];
+    const uint8_t * d = (const uint8_t *)uaux;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float   * y = yy + i * QK_K + offset;
+        const uint8_t * q = x[i].qs + offset;
+        const uint32_t * s = (const uint32_t *)x[i].scales;
+
+        uaux[0] = s[0] & 0x0f0f0f0f;
+        uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
+
+        const half2 * dh = (const half2 *)&x[i].d;
+
+        const float2 dall = __half22float2(dh[0]);
+
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            const uint8_t ql = q[l];
+            sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
+                  + y[l+16] * d[1] * ((ql >> 2) & 3)
+                  + y[l+32] * d[2] * ((ql >> 4) & 3)
+                  + y[l+48] * d[3] * ((ql >> 6) & 3);
+            sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
+        }
+        tmp += dall.x * sum1 - dall.y * sum2;
+    }
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -1351,7 +1396,11 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
 
 static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
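
For reference (not part of the patch), here is a minimal CPU-side sketch of the per-super-block unpacking that the new QK_K == 64 branch performs. The struct and function names are illustrative stand-ins; the real block_q2_K in k_quants.h stores d and dmin as fp16 rather than float.

#include <stdint.h>

// Illustrative stand-in for block_q2_K when QK_K == 64; the real struct keeps
// d and dmin as ggml_fp16_t, floats are used here only to stay self-contained.
typedef struct {
    uint8_t scales[4];  // low nibble: 4-bit scale, high nibble: 4-bit min
    uint8_t qs[16];     // 64 x 2-bit quants, four values packed per byte
    float   d;          // super-block scale
    float   dmin;       // super-block min
} q2_K_64_sketch;

// CPU reference of the QK_K == 64 dequantization above: byte qs[j] carries the
// 2-bit quants for output positions j, j+16, j+32, j+48, and each group of 16
// outputs uses one packed scale/min byte.
static void dequantize_q2_K_64_sketch(const q2_K_64_sketch * x, float * y) {
    for (int j = 0; j < 16; ++j) {
        const uint8_t q = x->qs[j];
        y[j +  0] = x->d * (x->scales[0] & 0xF) * ((q >> 0) & 3) - x->dmin * (x->scales[0] >> 4);
        y[j + 16] = x->d * (x->scales[1] & 0xF) * ((q >> 2) & 3) - x->dmin * (x->scales[1] >> 4);
        y[j + 32] = x->d * (x->scales[2] & 0xF) * ((q >> 4) & 3) - x->dmin * (x->scales[2] >> 4);
        y[j + 48] = x->d * (x->scales[3] & 0xF) * ((q >> 6) & 3) - x->dmin * (x->scales[3] >> 4);
    }
}

This is the same layout the new dequantize_mul_mat_vec_q2_k branch assumes when it splits s[0] into four 4-bit scales (uaux[0]) and four 4-bit mins (uaux[1]) and reads d/dmin together as a single half2.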