CUDA: refactor mmq, dmmv, mmvq (#7716)
* CUDA: refactor mmq, dmmv, mmvq * fix out-of-bounds write * struct for qk, qr, qi * fix cmake build * mmq_type_traits
This commit is contained in:
parent
2b3389677a
commit
7d1a378b8f
112 changed files with 1783 additions and 1767 deletions
|
@ -422,10 +422,22 @@ static __device__ void convert_f16(const void * vx, const int64_t ib, const int
|
|||
v.y = x[ib + iqs + 1];
|
||||
}
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||
static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
|
||||
return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 :
|
||||
type == GGML_TYPE_Q4_1 ? dequantize_q4_1 :
|
||||
type == GGML_TYPE_Q5_0 ? dequantize_q5_0 :
|
||||
type == GGML_TYPE_Q5_1 ? dequantize_q5_1 :
|
||||
type == GGML_TYPE_Q8_0 ? dequantize_q8_0 :
|
||||
type == GGML_TYPE_F16 ? convert_f16 :
|
||||
nullptr;
|
||||
}
|
||||
|
||||
template <ggml_type type>
|
||||
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
|
||||
// qk = quantized weights per x block
|
||||
// qr = number of quantized weights per data value in x block
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk; // quantized weights per x block
|
||||
constexpr int qr = ggml_cuda_type_traits<type>::qr; // number of quantized weights per data value in x block
|
||||
constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type);
|
||||
|
||||
const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
|
||||
|
||||
if (row >= nrows) {
|
||||
|
@ -493,7 +505,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
|
|||
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
|
||||
dequantize_mul_mat_vec<GGML_TYPE_Q4_0>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
|
@ -502,7 +514,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
|
|||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
|
||||
dequantize_mul_mat_vec<GGML_TYPE_Q4_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
|
@ -511,7 +523,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
|
|||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
|
||||
dequantize_mul_mat_vec<GGML_TYPE_Q5_0>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
|
@ -520,7 +532,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
|
|||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
|
||||
dequantize_mul_mat_vec<GGML_TYPE_Q5_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
|
@ -529,7 +541,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y,
|
|||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
|
||||
dequantize_mul_mat_vec<GGML_TYPE_Q8_0>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
|
@ -580,7 +592,7 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
|
|||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<1, 1, convert_f16>
|
||||
dequantize_mul_mat_vec<GGML_TYPE_F16>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue