Added q4_1 via template
parent 637be12f16
commit 12fc292ee6
1 changed file with 57 additions and 31 deletions
ggml-cuda.cu (88)
@@ -32,7 +32,9 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
         } \
     } while (0)
 
+typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
 typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
+typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
 
 #define QK4_0 32
 typedef struct {
@@ -73,6 +75,37 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 
+#define CUDA_DMMV_BLOCK_SIZE 32
+
+static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const float d = x[ib].d;
+
+    const uint8_t vui = x[ib].qs[iqs];
+
+    const int8_t vi0 = vui & 0xF;
+    const int8_t vi1 = vui >> 4;
+
+    v0 = (vi0 - 8)*d;
+    v1 = (vi1 - 8)*d;
+}
+
+static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+
+    const float d = x[ib].d;
+    const float m = x[ib].m;
+
+    const uint8_t vui = x[ib].qs[iqs];
+
+    const int8_t vi0 = vui & 0xF;
+    const int8_t vi1 = vui >> 4;
+
+    v0 = vi0*d + m;
+    v1 = vi1*d + m;
+}
+
 static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
     static const int qk = QK4_0;
 
@@ -173,10 +206,7 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
     }
 }
 
-template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
-    const block_q4_0 * x = (const block_q4_0 *) vx;
-    const int qk = QK4_0;
-
+template <int block_size, int qk, dequantize_kernel_t dequantize_kernel> static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
     const int row = blockIdx.x;
     const int tid = threadIdx.x;
 
@@ -190,17 +220,8 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
         const int iybs = col - col%qk; // y block start index
 
         // dequantize
-        const float d = x[ib].d;
-
-        const uint8_t * pp = x[ib].qs;
-
-        const uint8_t vui = pp[iqs];
-
-        const int8_t vi0 = vui & 0xF;
-        const int8_t vi1 = vui >> 4;
-
-        const float v0 = (vi0 - 8)*d;
-        const float v1 = (vi1 - 8)*d;
+        float v0, v1;
+        dequantize_kernel(vx, ib, iqs, v0, v1);
 
         // matrix multiplication
         tmp[tid] += v0 * y[iybs + iqs + 0];
@@ -244,21 +265,14 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
     dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
 }
 
-static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    // static int block_size = -1;
-    // if (block_size == -1) {
-    // int min_grid_size, max_block_size = 1;
-    // CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &max_block_size, dequantize_mul_mat_q4_0<256>, 0, 0));
-    // max_block_size = min(max_block_size, GGML_CUDA_MAX_BLOCK_SIZE);
-    // block_size = 1;
-    // while (block_size*2 <= max_block_size && block_size*2 % ncols == 0) {
-    // block_size *= 2;
-    // }
-    // }
-    // dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
-    const int block_size = 32;
-    GGML_ASSERT(ncols % block_size == 0);
-    dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
+static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, dequantize_q4_0><<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
 }
 
+static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, dequantize_q4_1><<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
 // TODO: optimize
@@ -293,6 +307,17 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     }
 }
 
+static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_mul_mat_vec_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_mul_mat_vec_q4_1_cuda;
+        default:
+            return nullptr;
+    }
+}
+
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 256
 
@@ -610,6 +635,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
 
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
+    dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type);
     GGML_ASSERT(to_fp32_cuda != nullptr);
 
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -641,7 +667,7 @@
                 CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
 
                 // compute
-                dequantize_mul_mat_q4_0_cuda(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
+                dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
                 CUDA_CHECK(cudaGetLastError());
 
             } else {
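For reference, below is a minimal host-side sketch (not part of this commit) that mirrors the two device dequantization formulas added above, assuming only the packed-nibble layout shown in the diff: q4_0 treats each 4-bit value as centred on 8 and scales it by d, while q4_1 keeps the nibble unsigned and adds the per-block minimum m. The helper names and sample values here are illustrative only.

// Illustrative host-side mirror of dequantize_q4_0 / dequantize_q4_1 (assumed same nibble packing as above).
#include <cstdint>
#include <cstdio>

static void dequant_q4_0_ref(float d, uint8_t vui, float & v0, float & v1) {
    const int8_t vi0 = vui & 0xF;  // low nibble
    const int8_t vi1 = vui >> 4;   // high nibble
    v0 = (vi0 - 8)*d;              // values centred on 8, no separate offset
    v1 = (vi1 - 8)*d;
}

static void dequant_q4_1_ref(float d, float m, uint8_t vui, float & v0, float & v1) {
    const int8_t vi0 = vui & 0xF;
    const int8_t vi1 = vui >> 4;
    v0 = vi0*d + m;                // unsigned nibble scaled by d, shifted by block minimum m
    v1 = vi1*d + m;
}

int main() {
    float v0, v1;
    dequant_q4_0_ref(0.5f, 0x3A, v0, v1);         // nibbles: 10 (low), 3 (high)
    printf("q4_0: %.2f %.2f\n", v0, v1);          // prints 1.00 -2.50
    dequant_q4_1_ref(0.5f, -1.0f, 0x3A, v0, v1);
    printf("q4_1: %.2f %.2f\n", v0, v1);          // prints 4.00 0.50
    return 0;
}

Passing the per-format dequantize function as a template parameter, as the diff does with dequantize_kernel_t, lets one dequantize_mul_mat_vec kernel serve both q4_0 and q4_1 without runtime branching in the inner loop.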