diff --git a/CMakeLists.txt b/CMakeLists.txt index 6c43c5ec5..b9936268e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,7 +68,8 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework" option(LLAMA_BLAS "llama: use BLAS" OFF) option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic) option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) -option(LLAMA_DMMV_BLOCK_Y "llama: y block size for dmmv CUDA kernels" 1) +set(LLAMA_CUDA_BY "1" CACHE STRING "llama: y block size for dmmv CUDA kernels") +option(LLAMA_CUDA_UNROLL "llama: unroll loops in dmmv CUDA kernels" OFF) option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) @@ -185,7 +186,8 @@ if (LLAMA_CUBLAS) set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h) add_compile_definitions(GGML_USE_CUBLAS) - add_compile_definitions(CUDA_DMMV_BLOCK_Y=${LLAMA_DMMV_BLOCK_Y}) + add_compile_definitions(GGML_CUDA_DMMV_BLOCK_Y=${LLAMA_CUDA_BY}) + add_compile_definitions($<$<BOOL:${LLAMA_CUDA_UNROLL}>:GGML_CUDA_UNROLL>) if (LLAMA_STATIC) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) diff --git a/Makefile b/Makefile index 3e46b2b5c..700868964 100644 --- a/Makefile +++ b/Makefile @@ -133,11 +133,14 @@ ifdef LLAMA_CUBLAS OBJS += ggml-cuda.o NVCC = nvcc NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native -ifdef LLAMA_DMMV_BLOCK_Y - NVCCFLAGS += -DCUDA_DMMV_BLOCK_Y=$(LLAMA_DMMV_BLOCK_Y) +ifdef LLAMA_CUDA_BY + NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_Y=$(LLAMA_CUDA_BY) else - NVCCFLAGS += -DCUDA_DMMV_BLOCK_Y=1 -endif # LLAMA_DMMV_BLOCK_Y + NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_Y=1 +endif # LLAMA_CUDA_BY +ifdef LLAMA_CUDA_UNROLL + NVCCFLAGS += -DGGML_CUDA_UNROLL +endif # LLAMA_CUDA_UNROLL ggml-cuda.o: ggml-cuda.cu ggml-cuda.h $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ endif # LLAMA_CUBLAS diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 
c9366705d..9fb5f50fe 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -85,7 +85,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo #define CUDA_MUL_BLOCK_SIZE 256 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 -#define CUDA_DMMV_BLOCK_X 32 // dmmv = dequantize_mul_mat_vec +#define GGML_CUDA_DMMV_BLOCK_X 32 // dmmv = dequantize_mul_mat_vec static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -200,8 +200,8 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k) dequantize_kernel(vx, ib, iqs, v0, v1); } -template -static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) { +template <int ncols, int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel> +static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; @@ -210,6 +210,9 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float tmp = 0; // partial sum for thread in warp +#ifdef GGML_CUDA_UNROLL +#pragma unroll +#endif for (int i = 0; i < ncols/block_size; i += 2) { const int col = i*block_size + 2*tid; const int ib = (row*ncols + col)/qk; // block index @@ -238,6 +241,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, } tmp = tmpa[0]; // now full sum #else +#pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); } @@ -278,36 +282,72 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu dequantize_block<<>>(vx, y, k); } +template <int qk, int qr, dequantize_kernel_t dequantize_kernel> +static void dequantize_mul_mat_vec_cuda(const void * vx, const float * y, float * dst, + const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0); + GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0); + const dim3 
block_dims(GGML_CUDA_DMMV_BLOCK_X, GGML_CUDA_DMMV_BLOCK_Y, 1); + + // Use a switch statement for ncols so the compiler can unroll all loops: + switch (ncols) { + case 4096: + dequantize_mul_mat_vec<4096, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel> + <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst); + break; + case 5120: + dequantize_mul_mat_vec<5120, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel> + <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst); + break; + case 6656: + dequantize_mul_mat_vec<6656, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel> + <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst); + break; + case 8192: + dequantize_mul_mat_vec<8192, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel> + <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst); + break; + case 11008: + dequantize_mul_mat_vec<11008, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel> + <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst); + break; + case 13824: + dequantize_mul_mat_vec<13824, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel> + <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst); + break; + case 17920: + dequantize_mul_mat_vec<17920, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel> + <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst); + break; + case 22016: + dequantize_mul_mat_vec<22016, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel> + <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst); + break; + default: + fprintf(stderr, "Tell the devs to add a switch case for this: ncols=%d\n", ncols); + GGML_ASSERT(false); + break; + } +} + static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0); - GGML_ASSERT(nrows % CUDA_DMMV_BLOCK_Y == 0); - const dim3 block_dims(CUDA_DMMV_BLOCK_X, CUDA_DMMV_BLOCK_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols); + dequantize_mul_mat_vec_cuda<QK4_0, QR4_0, dequantize_q4_0>(vx, y, dst, ncols, nrows, stream); } static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols); + 
dequantize_mul_mat_vec_cuda<QK4_1, QR4_1, dequantize_q4_1>(vx, y, dst, ncols, nrows, stream); } static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols); + dequantize_mul_mat_vec_cuda<QK5_0, QR5_0, dequantize_q5_0>(vx, y, dst, ncols, nrows, stream); } static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols); + dequantize_mul_mat_vec_cuda<QK5_1, QR5_1, dequantize_q5_1>(vx, y, dst, ncols, nrows, stream); } static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols); + dequantize_mul_mat_vec_cuda<QK8_0, QR8_0, dequantize_q8_0>(vx, y, dst, ncols, nrows, stream); } static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c @@ -316,9 +356,7 @@ } static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols); + dequantize_mul_mat_vec_cuda<1, 1, convert_f16>(vx, y, dst, ncols, nrows, stream); } static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {