ggml-cuda : remove device side dequantize
This commit is contained in:
parent
09e35d04b1
commit
6963441b22
1 changed files with 7 additions and 7 deletions
14
ggml-cuda.cu
14
ggml-cuda.cu
|
@ -5224,13 +5224,13 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
|
|||
}
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static __host__ __device__ void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
|
||||
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
||||
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __host__ __device__ void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
#if QK_K == 256
|
||||
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||
|
@ -5240,7 +5240,7 @@ static __host__ __device__ void dequantize_row_q2_K_cuda(const void * vx, dst_t
|
|||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __host__ __device__ void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
#if QK_K == 256
|
||||
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||
|
@ -5250,13 +5250,13 @@ static __host__ __device__ void dequantize_row_q3_K_cuda(const void * vx, dst_t
|
|||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __host__ __device__ void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __host__ __device__ void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
#if QK_K == 256
|
||||
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||
|
@ -5266,7 +5266,7 @@ static __host__ __device__ void dequantize_row_q5_K_cuda(const void * vx, dst_t
|
|||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __host__ __device__ void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
#if QK_K == 256
|
||||
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||
|
@ -5275,7 +5275,7 @@ static __host__ __device__ void dequantize_row_q6_K_cuda(const void * vx, dst_t
|
|||
#endif
|
||||
}
|
||||
|
||||
static to_fp16_cuda_t __host__ __device__ ggml_get_to_fp16_cuda(ggml_type type) {
|
||||
static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue