diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 05a7a8a09..f0d783418 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -262,6 +262,10 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_QUANTIZE_BLOCK_SIZE 256 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 +#ifndef GGML_CUDA_MMQ_Y +#define GGML_CUDA_MMQ_Y 128 +#endif + // dmmv = dequantize_mul_mat_vec #ifndef GGML_CUDA_DMMV_X #define GGML_CUDA_DMMV_X 32 @@ -1393,8 +1397,8 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1( static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_qs[(2*WARP_SIZE) * (WARP_SIZE + 1)]; - __shared__ half2 tile_x_d[(2*WARP_SIZE) * (WARP_SIZE/QI4_0)]; + __shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_0)]; *x_ql = tile_x_qs; *x_dm = tile_x_d; @@ -1467,8 +1471,8 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1( static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_qs[(2*WARP_SIZE) * (WARP_SIZE + 1)]; - __shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE/QI4_1)]; + __shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_1)]; *x_ql = tile_x_qs; *x_dm = tile_x_dm; @@ -1539,9 +1543,9 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1( static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)]; - __shared__ int tile_x_qh[(2*WARP_SIZE) * (WARP_SIZE/QI5_0)]; - __shared__ half2 tile_x_d[(2*WARP_SIZE) * (WARP_SIZE/QI5_0)]; + __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0)]; + __shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0)]; *x_ql = tile_x_ql; *x_qh = tile_x_qh; @@ -1625,9 +1629,9 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1( static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)]; - __shared__ int tile_x_qh[(2*WARP_SIZE) * (WARP_SIZE/QI5_1)]; - __shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE/QI5_1)]; + __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1)]; + __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1)]; *x_ql = tile_x_ql; *x_qh = tile_x_qh; @@ -1688,8 +1692,8 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1( static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_qs[(2*WARP_SIZE) * (WARP_SIZE + 1)]; - __shared__ half2 tile_x_d[(2*WARP_SIZE) * (WARP_SIZE/QI8_0)]; + __shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI8_0)]; *x_ql = tile_x_qs; *x_dm = tile_x_d; @@ -1772,9 +1776,9 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1( static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)]; - __shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE / QI2_K)]; - __shared__ int tile_x_sc[(2*WARP_SIZE) * (WARP_SIZE / 4)]; + __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE / QI2_K)]; + __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE / 4)]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -1799,7 +1803,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - __builtin_assume(i < 2*WARP_SIZE); + __builtin_assume(i < GGML_CUDA_MMQ_Y); __builtin_assume(j < WARP_SIZE); __builtin_assume(k < WARP_SIZE); @@ -1888,10 +1892,10 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1( static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)]; - __shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE / QI2_K)]; - __shared__ int tile_x_qh[(2*WARP_SIZE) * (WARP_SIZE / 2)]; - __shared__ int tile_x_sc[(2*WARP_SIZE) * (WARP_SIZE / 4)]; + __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE / QI2_K)]; + __shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE / 2)]; + __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE / 4)]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -2064,9 +2068,9 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)]; - __shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE/QI4_K)]; - __shared__ int tile_x_sc[(2*WARP_SIZE) * (3*WARP_SIZE/32)]; + __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_K)]; + __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (3*WARP_SIZE/32)]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -2091,7 +2095,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - __builtin_assume(i < 2*WARP_SIZE); + __builtin_assume(i < GGML_CUDA_MMQ_Y); __builtin_assume(j < WARP_SIZE); __builtin_assume(k < WARP_SIZE); @@ -2256,10 +2260,10 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)]; - __shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE/QI4_K)]; - __shared__ int tile_x_qh[(2*WARP_SIZE) * (WARP_SIZE/4)]; - __shared__ int tile_x_sc[(2*WARP_SIZE) * (3*WARP_SIZE/32)]; + __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_K)]; + __shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/4)]; + __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (3*WARP_SIZE/32)]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -2383,10 +2387,10 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1( static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)]; - __shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE/QI6_K)]; - __shared__ int tile_x_qh[(2*WARP_SIZE) * (WARP_SIZE/2)]; - __shared__ int tile_x_sc[(2*WARP_SIZE) * (WARP_SIZE/8)]; + __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI6_K)]; + __shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/2)]; + __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8)]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -2413,7 +2417,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - __builtin_assume(i < 2*WARP_SIZE); + __builtin_assume(i < GGML_CUDA_MMQ_Y); __builtin_assume(j < WARP_SIZE); __builtin_assume(k < WARP_SIZE); @@ -2459,7 +2463,7 @@ static __global__ void mul_mat_q( const int tid_x = threadIdx.x; const int tid_y = threadIdx.y; - const int row_dst_0 = 2*blockIdx.x*WARP_SIZE; + const int row_dst_0 = blockIdx.x*GGML_CUDA_MMQ_Y; const int & row_x_0 = row_dst_0; const int row_dst = row_dst_0 + tid_x; @@ -2478,11 +2482,11 @@ static __global__ void mul_mat_q( __shared__ int tile_y_qs[(WARP_SIZE) * (qr*WARP_SIZE)]; __shared__ half2 tile_y_ds[(WARP_SIZE) * blocks_per_tile_y_col]; - float sum[2][4] = {0.0f}; + float sum[GGML_CUDA_MMQ_Y/WARP_SIZE][4] = {0.0f}; for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { - for (int i = 0; i < 2*WARP_SIZE; i += 8) { + for (int i = 0; i < GGML_CUDA_MMQ_Y; i += 8) { load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, i + tid_y, tid_x, blocks_per_row_x); } @@ -2516,7 +2520,7 @@ static __global__ void mul_mat_q( #pragma unroll for (int j = 0; j < WARP_SIZE; j += 8) { #pragma unroll - for (int i = 0; i < 2*WARP_SIZE; i += WARP_SIZE) { + for (int i = 0; i < GGML_CUDA_MMQ_Y; i += WARP_SIZE) { sum[i/WARP_SIZE][j/8] += vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, tid_x + i, tid_y + j, k); } @@ -2538,8 +2542,9 @@ static __global__ void mul_mat_q( return; } - dst[col_dst*nrows_dst + row_dst] = sum[0][j/8]; - dst[col_dst*nrows_dst + row_dst + WARP_SIZE] = sum[1][j/8]; + for (int i = 0; i < GGML_CUDA_MMQ_Y; i += WARP_SIZE) { + dst[col_dst*nrows_dst + row_dst + i] = sum[i/WARP_SIZE][j/8]; + } } } @@ -3239,7 +3244,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; const dim3 block_nums(block_num_x, block_num_y, 1); const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); @@ -3251,7 +3256,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; const dim3 block_nums(block_num_x, block_num_y, 1); const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); @@ -3263,7 +3268,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; const dim3 block_nums(block_num_x, block_num_y, 1); const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); @@ -3275,7 +3280,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; const dim3 block_nums(block_num_x, block_num_y, 1); const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); @@ -3287,7 +3292,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; const dim3 block_nums(block_num_x, block_num_y, 1); const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); @@ -3299,7 +3304,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; const dim3 block_nums(block_num_x, block_num_y, 1); const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); @@ -3311,7 +3316,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; const dim3 block_nums(block_num_x, block_num_y, 1); const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); @@ -3323,7 +3328,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; const dim3 block_nums(block_num_x, block_num_y, 1); const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); @@ -3335,7 +3340,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; const dim3 block_nums(block_num_x, block_num_y, 1); const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); @@ -3347,7 +3352,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; const dim3 block_nums(block_num_x, block_num_y, 1); const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);