Fixed LLAMA_CUDA_DMMV_Y > 1 for WizardLM
This commit is contained in:
parent
a47072b85d
commit
b783da97a6
1 changed files with 55 additions and 30 deletions
85
ggml-cuda.cu
85
ggml-cuda.cu
|
@ -660,10 +660,15 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||||
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
|
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols, const int nrows) {
|
||||||
// qk = quantized weights per x block
|
// qk = quantized weights per x block
|
||||||
// qr = number of quantized weights per data value in x block
|
// qr = number of quantized weights per data value in x block
|
||||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||||
|
|
||||||
|
if (row >= nrows) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
|
|
||||||
const int iter_stride = 2*GGML_CUDA_DMMV_X;
|
const int iter_stride = 2*GGML_CUDA_DMMV_X;
|
||||||
|
@ -708,8 +713,13 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int n_thread, dot_kernel_k_t dot_kernel>
|
template <int n_thread, dot_kernel_k_t dot_kernel>
|
||||||
static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols) {
|
static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols, const int nrows) {
|
||||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||||
|
|
||||||
|
if (row >= nrows) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
|
|
||||||
const int iter_stride = QK_K;
|
const int iter_stride = QK_K;
|
||||||
|
@ -1035,73 +1045,92 @@ static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cu
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
||||||
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
|
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
|
||||||
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
||||||
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
|
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
|
||||||
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
||||||
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
|
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
|
||||||
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
||||||
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
|
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
|
||||||
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
||||||
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
|
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
|
||||||
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
GGML_ASSERT(ncols % QK_K == 0);
|
||||||
const int ny = 2;
|
const int ny = 2;
|
||||||
|
const int block_num_y = (nrows + ny - 1) / ny;
|
||||||
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
const dim3 block_dims(32, ny, 1);
|
const dim3 block_dims(32, ny, 1);
|
||||||
dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<<(nrows + ny - 1)/ny, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
GGML_ASSERT(ncols % QK_K == 0);
|
||||||
const dim3 block_dims(32, 2, 1);
|
const int ny = 2;
|
||||||
dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
const int block_num_y = (nrows + ny - 1) / ny;
|
||||||
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
const dim3 block_dims(32, ny, 1);
|
||||||
|
dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
GGML_ASSERT(ncols % QK_K == 0);
|
||||||
const dim3 block_dims(32, 2, 1);
|
const int ny = 2;
|
||||||
dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
const int block_num_y = (nrows + ny - 1) / ny;
|
||||||
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
const dim3 block_dims(32, ny, 1);
|
||||||
|
dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
GGML_ASSERT(ncols % QK_K == 0);
|
||||||
const dim3 block_dims(32, 2, 1);
|
const int ny = 2;
|
||||||
dequantize_mul_mat_vec_k<32, vec_dot_q5_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
const int block_num_y = (nrows + ny - 1) / ny;
|
||||||
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
const dim3 block_dims(32, ny, 1);
|
||||||
|
dequantize_mul_mat_vec_k<32, vec_dot_q5_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
GGML_ASSERT(ncols % QK_K == 0);
|
||||||
const dim3 block_dims(32, 2, 1);
|
const int ny = 2;
|
||||||
dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
const int block_num_y = (nrows + ny - 1) / ny;
|
||||||
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
const dim3 block_dims(32, ny, 1);
|
||||||
|
dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||||
|
@ -1111,10 +1140,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
|
||||||
|
|
||||||
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
||||||
dequantize_mul_mat_vec<1, 1, convert_f16>
|
dequantize_mul_mat_vec<1, 1, convert_f16>
|
||||||
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||||||
|
@ -1798,9 +1828,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
|
||||||
int64_t row_low, row_high;
|
int64_t row_low, row_high;
|
||||||
if (split) {
|
if (split) {
|
||||||
row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
|
row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
|
||||||
row_low -= row_low % GGML_CUDA_DMMV_Y;
|
|
||||||
row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
|
row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
|
||||||
row_high -= row_high % GGML_CUDA_DMMV_Y;
|
|
||||||
} else {
|
} else {
|
||||||
row_low = 0;
|
row_low = 0;
|
||||||
row_high = nrows0;
|
row_high = nrows0;
|
||||||
|
@ -2223,10 +2251,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
|
||||||
row_high = nrows;
|
row_high = nrows;
|
||||||
} else if (backend == GGML_BACKEND_GPU_SPLIT) {
|
} else if (backend == GGML_BACKEND_GPU_SPLIT) {
|
||||||
row_low = id == 0 ? 0 : nrows*g_tensor_split[id];
|
row_low = id == 0 ? 0 : nrows*g_tensor_split[id];
|
||||||
row_low -= row_low % GGML_CUDA_DMMV_Y;
|
|
||||||
row_high = id == g_device_count - 1 ? nrows : nrows*g_tensor_split[id + 1];
|
row_high = id == g_device_count - 1 ? nrows : nrows*g_tensor_split[id + 1];
|
||||||
row_high -= row_high % GGML_CUDA_DMMV_Y;
|
|
||||||
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
|
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue