diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index a241ae84f..d6387fd92 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -225,21 +225,24 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
     }
 }
 
-static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, int ncols, int nrows) {
+template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
-    const int row = blockIdx.x*blockDim.x + threadIdx.x;
+    const int row = blockIdx.x;
+    const int tid = threadIdx.x;
 
-    if (row >= nrows) {
-        return;
-    }
-    dst[row] = 0;
-    for (int i = 0; i < ncols; i += 2) {
-        const float d = x[(row*ncols + i)/QK4_0].d;
+    __shared__ float tmp[block_size]; // separate sum for each thread
+    tmp[tid] = 0;
 
-        const uint8_t * pp = x[(row*ncols + i)/QK4_0].qs;
+    for (int i = 0; i < ncols/block_size; i += 2) {
+        const int col = i*block_size + 2*tid;
 
-        const uint8_t vui = pp[((row*ncols + i)%QK4_0)/2];
+        // dequantize
+        const float d = x[(row*ncols + col)/QK4_0].d;
+
+        const uint8_t * pp = x[(row*ncols + col)/QK4_0].qs;
+
+        const uint8_t vui = pp[((row*ncols + col)%QK4_0)/2];
 
         const int8_t vi0 = vui & 0xF;
         const int8_t vi1 = vui >> 4;
@@ -247,8 +250,20 @@ static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y,
         const float v0 = (vi0 - 8)*d;
         const float v1 = (vi1 - 8)*d;
 
-        dst[row] += v0 * y[i + 0];
-        dst[row] += v1 * y[i + 1];
+        // matrix multiplication
+        tmp[tid] += v0 * y[col + 0];
+        tmp[tid] += v1 * y[col + 1];
+    }
+
+    // sum up partial sums and write back result
+    for (int s=block_size/2; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        __syncthreads();
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
     }
 }
 
@@ -282,15 +297,21 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
     dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
 }
 
-static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, int ncols, int nrows, cudaStream_t stream) {
-    static int block_size = -1;
-    if (block_size == -1) {
-        int min_grid_size;
-        CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_mul_mat_q4_0, 0, 0));
-        block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE);
-    }
-    const int grid_size = (nrows + block_size - 1) / block_size; // Round up.
-    dequantize_mul_mat_q4_0<<<grid_size, block_size, 0, stream>>>(vx, y, dst, ncols, nrows);
+static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    // static int block_size = -1;
+    // if (block_size == -1) {
+    //     int min_grid_size, max_block_size = 1;
+    //     CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &max_block_size, dequantize_mul_mat_q4_0<256>, 0, 0));
+    //     max_block_size = min(max_block_size, GGML_CUDA_MAX_BLOCK_SIZE);
+    //     block_size = 1;
+    //     while (block_size*2 <= max_block_size && block_size*2 % ncols == 0) {
+    //         block_size *= 2;
+    //     }
+    // }
+    // dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
+    const int block_size = 32;
+    GGML_ASSERT(ncols % block_size == 0);
+    dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
 }
 
 // TODO: optimize
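
For context, the new kernel computes dst[row] = sum over col of dequant(x[row*ncols + col]) * y[col], assigning one thread block per row (row = blockIdx.x) and reducing the per-thread partial sums through shared memory. Below is a minimal CPU reference for the same product, handy for validating the kernel output against known inputs. It is a sketch, not part of the patch: it assumes the ggml definitions QK4_0 == 32 and block_q4_0 = { float d; uint8_t qs[QK4_0/2]; }, and mul_mat_q4_0_ref is a hypothetical helper name.

#include <cstdint>

#define QK4_0 32

typedef struct {
    float   d;              // scale factor for the block
    uint8_t qs[QK4_0 / 2];  // 32 quants, two 4-bit values per byte
} block_q4_0;

// CPU reference for the product the kernel computes:
// dst[row] = sum over col of dequantized x[row][col] * y[col]
static void mul_mat_q4_0_ref(const block_q4_0 * x, const float * y, float * dst,
                             const int ncols, const int nrows) {
    for (int row = 0; row < nrows; ++row) {
        double sum = 0.0; // accumulate in double to reduce rounding noise
        for (int col = 0; col < ncols; col += 2) {
            const block_q4_0 * b  = &x[(row*ncols + col)/QK4_0];
            const uint8_t     vui = b->qs[((row*ncols + col)%QK4_0)/2];

            // same dequantization as the kernel: 4-bit value, offset by 8, scaled by d
            const float v0 = ((int8_t)(vui & 0xF) - 8)*b->d;
            const float v1 = ((int8_t)(vui >>  4) - 8)*b->d;

            sum += v0*y[col + 0] + v1*y[col + 1];
        }
        dst[row] = (float) sum;
    }
}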
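Design note on the launch in dequantize_mul_mat_q4_0_cuda: block_size is a template parameter because the shared-memory array tmp[block_size] needs a compile-time size, which is presumably why the occupancy-based block-size selection is left commented out. The launch is instead fixed at 32 threads per block, i.e. one warp per row, with a grid of nrows blocks to match row = blockIdx.x in the kernel, and the GGML_ASSERT guards the requirement that ncols is a multiple of block_size.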