diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 84c5e8b7d..a59110291 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -255,8 +255,8 @@ static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs
     uint8_t m = 1 << (4*n + j);
 
     const float d = x[ib].d;
-    const uint8_t * q = x[ib].qs + 32*n;
-    const uint8_t * hm = x[ib].hmask;
+    const uint8_t * q = x[ib].qs + 32*n + l;
+    const uint8_t * hm = x[ib].hmask + l;
 
     int8_t us = is < 4 ? (x[ib].scales[is-0] & 0xF) | (((x[ib].scales[is+8] >> 0) & 3) << 4) :
                 is < 8 ? (x[ib].scales[is-0] & 0xF) | (((x[ib].scales[is+4] >> 2) & 3) << 4) :
@@ -266,8 +266,8 @@
 
     float sum = 0;
     for (int k = 0; k < 8; ++k) {
-        int8_t ql = (q[l + k] >> shift) & 3;
-        sum += y[iqs + k] * (ql - ((hm[l + k] & m) ? 0 : 4));
+        int8_t ql = (q[k] >> shift) & 3;
+        sum += y[iqs + k] * (ql - ((hm[k] & m) ? 0 : 4));
     }
     result = sum * scale;
 }
@@ -362,13 +362,9 @@ static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y
         const int iqs = col%QK_K; // x quant index
         const int iybs = col - col%QK_K; // y block start index
 
-// processing >2 values per i iter is faster for fast GPUs
-#pragma unroll
-        for (int j = 0; j < vals_per_iter; j += 8) {
-            float v;
-            dot_kernel(vx, ib, iqs + j, y + iybs, v);
-            tmp += v;
-        }
+        float v;
+        dot_kernel(vx, ib, iqs, y + iybs, v);
+        tmp += v;
     }
 
     // sum up partial sums and write back result
@@ -460,8 +456,8 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f
 
 static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    dequantize_mul_mat_vec_k<<>>(vx, y, dst, ncols);
+    const dim3 block_dims(WARP_SIZE, 2, 1);
+    dequantize_mul_mat_vec_k<<>>(vx, y, dst, ncols);
 }
 
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
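
The arguments inside the CUDA `<<< >>>` launches in the last hunk did not survive the paste (only `<<>>` remains), so the grid size that pairs with the new `(WARP_SIZE, 2, 1)` block shape is not visible here. Below is a rough, self-contained sketch of the thread layout such a launch usually implies: one warp reduces one output row, two rows per block. The names `toy_mat_vec`/`launch_toy_mat_vec`, the `(nrows + 1) / 2` grid, and the row indexing are illustrative assumptions, not lines from this patch.

```cpp
// Illustration only: a plain fp32 mat-vec using a (WARP_SIZE, 2, 1) block.
// Each block covers blockDim.y = 2 rows; threadIdx.y picks the row, and the
// 32 lanes of that warp stride over the columns, then reduce their partials.
#define WARP_SIZE 32

static __global__ void toy_mat_vec(const float * x, const float * y, float * dst,
                                   const int ncols, const int nrows) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y; // assumed row mapping
    const int tid = threadIdx.x;

    if (row >= nrows) {
        return;
    }

    float tmp = 0.0f; // partial sum for this lane

    for (int col = tid; col < ncols; col += WARP_SIZE) {
        tmp += x[row*ncols + col] * y[col];
    }

    // warp-level reduction of the partial sums ("sum up partial sums and write back result")
    #pragma unroll
    for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, WARP_SIZE);
    }

    if (tid == 0) {
        dst[row] = tmp;
    }
}

static void launch_toy_mat_vec(const float * x, const float * y, float * dst,
                               const int ncols, const int nrows, cudaStream_t stream) {
    const dim3 block_dims(WARP_SIZE, 2, 1);   // two warps per block, one row each
    const int  block_nums = (nrows + 1) / 2;  // assumed grid: one block per pair of rows
    toy_mat_vec<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols, nrows);
}
```

Read this way, moving from `(WARP_SIZE, 1, 1)` to `(WARP_SIZE, 2, 1)` keeps one warp per output row but packs two rows into each block, halving the number of blocks the scheduler has to launch.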