loop unrolling
This commit is contained in:
parent
1a787101cc
commit
82cf01f897
3 changed files with 72 additions and 29 deletions
|
@ -68,7 +68,8 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework"
|
||||||
option(LLAMA_BLAS "llama: use BLAS" OFF)
|
option(LLAMA_BLAS "llama: use BLAS" OFF)
|
||||||
option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic)
|
option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic)
|
||||||
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
|
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
|
||||||
option(LLAMA_DMMV_BLOCK_Y "llama: y block size for dmmv CUDA kernels" 1)
|
option(LLAMA_CUDA_BY "llama: y block size for dmmv CUDA kernels" 1)
|
||||||
|
option(LLAMA_CUDA_UNROLL "llama: unroll loops in dmmv CUDA kernels" OFF)
|
||||||
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
|
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
|
||||||
|
|
||||||
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
|
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
|
||||||
|
@ -185,7 +186,8 @@ if (LLAMA_CUBLAS)
|
||||||
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
|
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
|
||||||
|
|
||||||
add_compile_definitions(GGML_USE_CUBLAS)
|
add_compile_definitions(GGML_USE_CUBLAS)
|
||||||
add_compile_definitions(CUDA_DMMV_BLOCK_Y=${LLAMA_DMMV_BLOCK_Y})
|
add_compile_definitions(GGML_CUDA_DMMV_BLOCK_Y=${LLAMA_CUDA_BY})
|
||||||
|
add_compile_definitions(GGML_CUDA_UNROLL=${LLAMA_CUDA_UNROLL})
|
||||||
|
|
||||||
if (LLAMA_STATIC)
|
if (LLAMA_STATIC)
|
||||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
||||||
|
|
11
Makefile
11
Makefile
|
@ -133,11 +133,14 @@ ifdef LLAMA_CUBLAS
|
||||||
OBJS += ggml-cuda.o
|
OBJS += ggml-cuda.o
|
||||||
NVCC = nvcc
|
NVCC = nvcc
|
||||||
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
|
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
|
||||||
ifdef LLAMA_DMMV_BLOCK_Y
|
ifdef LLAMA_CUDA_BY
|
||||||
NVCCFLAGS += -DCUDA_DMMV_BLOCK_Y=$(LLAMA_DMMV_BLOCK_Y)
|
NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_Y=$(LLAMA_CUDA_BY)
|
||||||
else
|
else
|
||||||
NVCCFLAGS += -DCUDA_DMMV_BLOCK_Y=1
|
NVCCFLAGS += -DGGML_CUDA_DMMV_BLOCK_Y=1
|
||||||
endif # LLAMA_DMMV_BLOCK_Y
|
endif # LLAMA_CUDA_BY
|
||||||
|
ifdef LLAMA_CUDA_UNROLL
|
||||||
|
NVCCFLAGS += -DGGML_CUDA_UNROLL=$(LLAMA_CUDA_UNROLL)
|
||||||
|
endif # LLAMA_CUDA_UNROLL
|
||||||
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
|
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
|
||||||
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
|
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
|
||||||
endif # LLAMA_CUBLAS
|
endif # LLAMA_CUBLAS
|
||||||
|
|
84
ggml-cuda.cu
84
ggml-cuda.cu
|
@ -85,7 +85,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
|
||||||
|
|
||||||
#define CUDA_MUL_BLOCK_SIZE 256
|
#define CUDA_MUL_BLOCK_SIZE 256
|
||||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
||||||
#define CUDA_DMMV_BLOCK_X 32 // dmmv = dequantize_mul_mat_vec
|
#define GGML_CUDA_DMMV_BLOCK_X 32 // dmmv = dequantize_mul_mat_vec
|
||||||
|
|
||||||
static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
|
static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
@ -200,8 +200,8 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
|
||||||
dequantize_kernel(vx, ib, iqs, v0, v1);
|
dequantize_kernel(vx, ib, iqs, v0, v1);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
template <int ncols, int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||||
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
|
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst) {
|
||||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
|
|
||||||
|
@ -210,6 +210,9 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
|
||||||
|
|
||||||
float tmp = 0; // partial sum for thread in warp
|
float tmp = 0; // partial sum for thread in warp
|
||||||
|
|
||||||
|
#ifdef GGML_CUDA_UNROLL
|
||||||
|
#pragma unroll
|
||||||
|
#endif
|
||||||
for (int i = 0; i < ncols/block_size; i += 2) {
|
for (int i = 0; i < ncols/block_size; i += 2) {
|
||||||
const int col = i*block_size + 2*tid;
|
const int col = i*block_size + 2*tid;
|
||||||
const int ib = (row*ncols + col)/qk; // block index
|
const int ib = (row*ncols + col)/qk; // block index
|
||||||
|
@ -238,6 +241,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
|
||||||
}
|
}
|
||||||
tmp = tmpa[0]; // now full sum
|
tmp = tmpa[0]; // now full sum
|
||||||
#else
|
#else
|
||||||
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||||
}
|
}
|
||||||
|
@ -278,36 +282,72 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
|
||||||
dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<dequantize_kernel_t dequantize_kernel, int qk, int qr>
|
||||||
|
static void dequantize_mul_mat_vec_cuda(const void * vx, const float * y, float * dst,
|
||||||
|
const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
|
||||||
|
GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
|
||||||
|
const dim3 block_dims(GGML_CUDA_DMMV_BLOCK_X, GGML_CUDA_DMMV_BLOCK_Y, 1);
|
||||||
|
|
||||||
|
// Use a switch statement for ncols so the compiler can unroll all loops:
|
||||||
|
switch (ncols) {
|
||||||
|
case 4096:
|
||||||
|
dequantize_mul_mat_vec<4096, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
|
||||||
|
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
|
||||||
|
break;
|
||||||
|
case 5120:
|
||||||
|
dequantize_mul_mat_vec<5120, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
|
||||||
|
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
|
||||||
|
break;
|
||||||
|
case 6656:
|
||||||
|
dequantize_mul_mat_vec<6656, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
|
||||||
|
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
|
||||||
|
break;
|
||||||
|
case 8192:
|
||||||
|
dequantize_mul_mat_vec<8192, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
|
||||||
|
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
|
||||||
|
break;
|
||||||
|
case 11008:
|
||||||
|
dequantize_mul_mat_vec<11008, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
|
||||||
|
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
|
||||||
|
break;
|
||||||
|
case 13824:
|
||||||
|
dequantize_mul_mat_vec<13824, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
|
||||||
|
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
|
||||||
|
break;
|
||||||
|
case 17920:
|
||||||
|
dequantize_mul_mat_vec<17920, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
|
||||||
|
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
|
||||||
|
break;
|
||||||
|
case 22016:
|
||||||
|
dequantize_mul_mat_vec<22016, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
|
||||||
|
<<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
fprintf(stderr, "Tell the devs to add a switch case for this: ncols=%d\n", ncols);
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
|
dequantize_mul_mat_vec_cuda<dequantize_q4_0, QK4_0, QR4_0>(vx, y, dst, ncols, nrows, stream);
|
||||||
GGML_ASSERT(nrows % CUDA_DMMV_BLOCK_Y == 0);
|
|
||||||
const dim3 block_dims(CUDA_DMMV_BLOCK_X, CUDA_DMMV_BLOCK_Y, 1);
|
|
||||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK4_0, QR4_0, dequantize_q4_0>
|
|
||||||
<<<nrows/CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
|
dequantize_mul_mat_vec_cuda<dequantize_q4_1, QK4_1, QR4_1>(vx, y, dst, ncols, nrows, stream);
|
||||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK4_1, QR4_1, dequantize_q4_1>
|
|
||||||
<<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
|
dequantize_mul_mat_vec_cuda<dequantize_q5_0, QK5_0, QR5_0>(vx, y, dst, ncols, nrows, stream);
|
||||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK5_0, QR5_0, dequantize_q5_0>
|
|
||||||
<<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
|
dequantize_mul_mat_vec_cuda<dequantize_q5_1, QK5_1, QR5_1>(vx, y, dst, ncols, nrows, stream);
|
||||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK5_1, QR5_1, dequantize_q5_1>
|
|
||||||
<<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
|
dequantize_mul_mat_vec_cuda<dequantize_q8_0, QK8_0, QR8_0>(vx, y, dst, ncols, nrows, stream);
|
||||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, QK8_0, QR8_0, dequantize_q8_0>
|
|
||||||
<<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||||
|
@ -316,9 +356,7 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
|
||||||
}
|
}
|
||||||
|
|
||||||
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_X == 0);
|
dequantize_mul_mat_vec_cuda<convert_f16, 1, 1>(vx, y, dst, ncols, nrows, stream);
|
||||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_X, 32, 1, convert_f16>
|
|
||||||
<<<nrows, CUDA_DMMV_BLOCK_X, 0, stream>>>(vx, y, dst, ncols);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue