From 98bfee013b18189a3335d6cf392fb69bbcd4e4ee Mon Sep 17 00:00:00 2001
From: JohannesGaessler
Date: Sun, 21 May 2023 12:01:14 +0200
Subject: [PATCH] Fewer iters, more ops per iter

---
 CMakeLists.txt |   4 +-
 Makefile       |   8 +--
 ggml-cuda.cu   | 130 +++++++++++++++++++++++--------------------------
 3 files changed, 67 insertions(+), 75 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0dff0005d..6906a849a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -68,8 +68,8 @@ option(LLAMA_ACCELERATE             "llama: enable Accelerate framework"
 option(LLAMA_BLAS                    "llama: use BLAS"                                  OFF)
 option(LLAMA_BLAS_VENDOR             "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic)
 option(LLAMA_CUBLAS                  "llama: use cuBLAS"                                OFF)
+set(LLAMA_CUDA_BX "32" CACHE STRING  "llama: x block size for dmmv CUDA kernels")
 set(LLAMA_CUDA_BY "1" CACHE STRING   "llama: y block size for dmmv CUDA kernels")
-option(LLAMA_CUDA_UNROLL             "llama: unroll loops in dmmv CUDA kernels"         OFF)
 option(LLAMA_CLBLAST                 "llama: use CLBlast"                               OFF)
 
 option(LLAMA_BUILD_TESTS             "llama: build tests"    ${LLAMA_STANDALONE})
@@ -186,8 +186,8 @@ if (LLAMA_CUBLAS)
 
         set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
         add_compile_definitions(GGML_USE_CUBLAS)
+        add_compile_definitions(GGML_CUDA_DMMV_BLOCK_X=${LLAMA_CUDA_BX})
         add_compile_definitions(GGML_CUDA_DMMV_BLOCK_Y=${LLAMA_CUDA_BY})
-        add_compile_definitions(GGML_CUDA_UNROLL=${LLAMA_CUDA_UNROLL})
 
         if (LLAMA_STATIC)
             set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
diff --git a/Makefile b/Makefile
index 700868964..56d67977b 100644
--- a/Makefile
+++ b/Makefile
@@ -133,14 +133,16 @@ ifdef LLAMA_CUBLAS
 	OBJS      += ggml-cuda.o
 	NVCC      = nvcc
 	NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
+ifdef LLAMA_CUDA_BX
+	NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_X=$(LLAMA_CUDA_BX)
+else
+	NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_X=32
+endif # LLAMA_CUDA_BX
 ifdef LLAMA_CUDA_BY
 	NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_Y=$(LLAMA_CUDA_BY)
 else
 	NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_Y=1
 endif # LLAMA_CUDA_BY
-ifdef LLAMA_CUDA_UNROLL
-	NVCCFLAGS += -DGGML_CUDA_UNROLL=$(LLAMA_CUDA_UNROLL)
-endif # LLAMA_CUDA_UNROLL
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
 	$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
 endif # LLAMA_CUBLAS
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index afab6704f..6b9176697 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -83,10 +83,16 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
+#define WARP_SIZE 32
+
 #define CUDA_MUL_BLOCK_SIZE 256
+
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
+
 // dmmv = dequantize_mul_mat_vec
-#define GGML_CUDA_DMMV_BLOCK_X 32
+#ifndef GGML_CUDA_DMMV_BLOCK_X
+#define GGML_CUDA_DMMV_BLOCK_X 32 // can be set by compiler option LLAMA_CUDA_BX
+#endif
 #ifndef GGML_CUDA_DMMV_BLOCK_Y
 #define GGML_CUDA_DMMV_BLOCK_Y 1 // can by set by compiler option LLAMA_CUDA_BY
 #endif
@@ -204,32 +210,40 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k
     dequantize_kernel(vx, ib, iqs, v0, v1);
 }
 
-template <int ncols, int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst) {
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
+    // qk = quantized weights per x block
+    // qr = number of quantized weights per data value in x block
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
+    const int iter_stride = 2*GGML_CUDA_DMMV_BLOCK_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread per i iter
     const int y_offset = qr == 1 ? 1 : qk/2;
-
     float tmp = 0; // partial sum for thread in warp
 
-#ifdef GGML_CUDA_UNROLL
-#pragma unroll
-#endif
-    for (int i = 0; i < ncols/block_size; i += 2) {
-        const int col = i*block_size + 2*tid;
-        const int ib = (row*ncols + col)/qk; // block index
-        const int iqs = (col%qk)/qr; // quant index
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
         const int iybs = col - col%qk; // y block start index
 
-        // dequantize
-        float v0, v1;
-        dequantize_kernel(vx, ib, iqs, v0, v1);
+// processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
 
-        // matrix multiplication
-        tmp += v0 * y[iybs + iqs + 0];
-        tmp += v1 * y[iybs + iqs + y_offset];
+            // dequantize
+            float v0, v1;
+            dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+
+            // matrix multiplication
+            tmp += v0 * y[iybs + iqs + j/qr + 0];
+            tmp += v1 * y[iybs + iqs + j/qr + y_offset];
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+        }
     }
 
     // sum up partial sums and write back result
@@ -274,72 +288,44 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
     dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static void dequantize_mul_mat_vec_cuda(const void * vx, const float * y, float * dst,
-                                        const int ncols, const int nrows, cudaStream_t stream) {
+static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
     GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
-    const dim3 block_dims(GGML_CUDA_DMMV_BLOCK_X, GGML_CUDA_DMMV_BLOCK_Y, 1);
-
-    // Use a switch statement for ncols so the compiler can unroll all loops:
-    switch (ncols) {
-        case 4096:
-            dequantize_mul_mat_vec<4096, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 5120:
-            dequantize_mul_mat_vec<5120, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 6656:
-            dequantize_mul_mat_vec<6656, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 8192:
-            dequantize_mul_mat_vec<8192, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 11008:
-            dequantize_mul_mat_vec<11008, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 13824:
-            dequantize_mul_mat_vec<13824, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 17920:
-            dequantize_mul_mat_vec<17920, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 22016:
-            dequantize_mul_mat_vec<22016, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        default:
-            fprintf(stderr, "Tell the devs to add a switch case for this: ncols=%d\n", ncols);
-            GGML_ASSERT(false);
-            break;
-    }
-}
-
-static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<QK4_0, QR4_0, dequantize_q4_0>(vx, y, dst, ncols, nrows, stream);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<QK4_1, QR4_1, dequantize_q4_1>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<QK5_0, QR5_0, dequantize_q5_0>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<QK5_1, QR5_1, dequantize_q5_1>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<QK8_0, QR8_0, dequantize_q8_0>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -348,7 +334,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
 }
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<1, 1, convert_f16>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<1, 1, convert_f16>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
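
The core of the change: instead of specializing dequantize_mul_mat_vec on ncols
at compile time (one switch case per supported model width) and unrolling a
2-values-per-thread loop, the kernel now walks ncols at runtime in strides of
iter_stride = 2*GGML_CUDA_DMMV_BLOCK_X, and each of the 32 threads in a warp
processes vals_per_iter = iter_stride/WARP_SIZE values per outer iteration in an
unrolled inner loop. Raising GGML_CUDA_DMMV_BLOCK_X therefore trades fewer outer
iterations for more ops per iteration. The stand-alone sketch below shows the
same iteration and warp-reduction pattern with the quantization stripped out.
It is illustrative only: plain float loads stand in for dequantize_kernel, and
mul_mat_vec, BLOCK_X, and the main() driver are placeholder names, not code from
the patch.

// dmmv_sketch.cu -- hypothetical, simplified illustration of the iteration
// scheme used by dequantize_mul_mat_vec after this patch.
#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE 32
#define BLOCK_X   32 // stands in for GGML_CUDA_DMMV_BLOCK_X

// One warp per matrix row: each outer iteration covers iter_stride columns,
// so each of the 32 threads handles vals_per_iter columns per iteration.
static __global__ void mul_mat_vec(const float * x, const float * y, float * dst, const int ncols) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    const int iter_stride   = 2*BLOCK_X;
    const int vals_per_iter = iter_stride / WARP_SIZE;

    float tmp = 0.0f; // partial sum for this thread

    for (int i = 0; i < ncols; i += iter_stride) {
        const int col = i + vals_per_iter*tid;
#pragma unroll
        for (int j = 0; j < vals_per_iter; ++j) {
            tmp += x[row*ncols + col + j] * y[col + j];
        }
    }

    // reduce the 32 partial sums within the warp
    for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }

    if (tid == 0) {
        dst[row] = tmp;
    }
}

int main() {
    const int nrows = 4;
    const int ncols = 4096; // must be a multiple of iter_stride

    float *x, *y, *dst;
    cudaMallocManaged(&x,   nrows*ncols*sizeof(float));
    cudaMallocManaged(&y,   ncols*sizeof(float));
    cudaMallocManaged(&dst, nrows*sizeof(float));
    for (int i = 0; i < nrows*ncols; ++i) { x[i] = (i % 7) * 0.25f; }
    for (int i = 0; i < ncols;       ++i) { y[i] = 1.0f; }

    const dim3 block_dims(WARP_SIZE, 1, 1); // blockDim.y plays the role of GGML_CUDA_DMMV_BLOCK_Y
    mul_mat_vec<<<nrows, block_dims>>>(x, y, dst, ncols);
    cudaDeviceSynchronize();

    for (int r = 0; r < nrows; ++r) {
        printf("dst[%d] = %f\n", r, dst[r]);
    }
    cudaFree(x); cudaFree(y); cudaFree(dst);
    return 0;
}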
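With the per-ncols switch gone, the work per iteration is tuned at build time
rather than through template instantiations. Assuming an otherwise standard
build, the block sizes can be set with, e.g.,
make LLAMA_CUBLAS=1 LLAMA_CUDA_BX=64 LLAMA_CUDA_BY=2 or
cmake -DLLAMA_CUBLAS=ON -DLLAMA_CUDA_BX=64 -DLLAMA_CUDA_BY=2. Note that ncols
and nrows must remain multiples of GGML_CUDA_DMMV_BLOCK_X and
GGML_CUDA_DMMV_BLOCK_Y respectively, which the added GGML_ASSERT calls enforce.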