block y dim

This commit is contained in:
JohannesGaessler 2023-05-19 17:24:05 +02:00
parent fbf5588abc
commit 1a787101cc
3 changed files with 30 additions and 21 deletions

View file

@ -68,6 +68,7 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework"
option(LLAMA_BLAS "llama: use BLAS" OFF)
option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic)
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
option(LLAMA_DMMV_BLOCK_Y "llama: y block size for dmmv CUDA kernels" 1)
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
@ -184,6 +185,7 @@ if (LLAMA_CUBLAS)
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
add_compile_definitions(GGML_USE_CUBLAS)
add_compile_definitions(CUDA_DMMV_BLOCK_Y=${LLAMA_DMMV_BLOCK_Y})
if (LLAMA_STATIC)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)

View file

@ -133,9 +133,14 @@ ifdef LLAMA_CUBLAS
OBJS += ggml-cuda.o
NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
ifdef LLAMA_DMMV_BLOCK_Y
NVCCFLAGS += -DCUDA_DMMV_BLOCK_Y=$(LLAMA_DMMV_BLOCK_Y)
else
NVCCFLAGS += -DCUDA_DMMV_BLOCK_Y=1
endif # LLAMA_DMMV_BLOCK_Y
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif
endif # LLAMA_CUBLAS
ifdef LLAMA_CLBLAST
CFLAGS += -DGGML_USE_CLBLAST
CXXFLAGS += -DGGML_USE_CLBLAST

View file

@ -85,7 +85,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
#define CUDA_MUL_BLOCK_SIZE 256
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
#define CUDA_DMMV_BLOCK_X 32 // dmmv = dequantize_mul_mat_vec
static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@ -202,7 +202,7 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
const int row = blockIdx.x;
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int y_offset = qr == 1 ? 1 : qk/2;
@ -279,33 +279,35 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
}
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
GGML_ASSERT(nrows % CUDA_DMMV_BLOCK_Y == 0);
const dim3 block_dims(CUDA_DMMV_BLOCK_X, CUDA_DMMV_BLOCK_Y, 1);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK4_0, QR4_0, dequantize_q4_0>
<<<nrows/CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK4_1, QR4_1, dequantize_q4_1>
<<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
}
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK5_0, QR5_0, dequantize_q5_0>
<<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
}
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK5_1, QR5_1, dequantize_q5_1>
<<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
}
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK8_0, QR8_0, dequantize_q8_0>
<<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
}
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@ -314,9 +316,9 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
}
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16>
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, 32, 1, convert_f16>
<<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
}
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {