dequantize_mul_mat_vec kernels for q5_1, q8_0, f16
This commit is contained in:
parent
5a0ecf768d
commit
bb0993ed48
1 changed files with 81 additions and 6 deletions
87
ggml-cuda.cu
87
ggml-cuda.cu
|
@ -36,7 +36,11 @@ typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs
|
|||
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
|
||||
typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
|
||||
|
||||
// QK = number of values after dequantization
|
||||
// QR = QK / number of values before dequantization
|
||||
|
||||
#define QK4_0 32
|
||||
#define QR4_0 2
|
||||
typedef struct {
|
||||
float d; // delta
|
||||
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
||||
|
@ -44,6 +48,7 @@ typedef struct {
|
|||
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
|
||||
|
||||
#define QK4_1 32
|
||||
#define QR4_1 2
|
||||
typedef struct {
|
||||
float d; // delta
|
||||
float m; // min
|
||||
|
@ -52,6 +57,7 @@ typedef struct {
|
|||
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
|
||||
|
||||
#define QK5_0 32
|
||||
#define QR5_0 2
|
||||
typedef struct {
|
||||
half d; // delta
|
||||
uint8_t qh[4]; // 5-th bit of quants
|
||||
|
@ -60,6 +66,7 @@ typedef struct {
|
|||
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
|
||||
|
||||
#define QK5_1 32
|
||||
#define QR5_1 2
|
||||
typedef struct {
|
||||
half d; // delta
|
||||
half m; // min
|
||||
|
@ -69,6 +76,7 @@ typedef struct {
|
|||
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
|
||||
|
||||
#define QK8_0 32
|
||||
#define QR8_0 1
|
||||
typedef struct {
|
||||
float d; // delta
|
||||
int8_t qs[QK8_0]; // quants
|
||||
|
@ -124,6 +132,44 @@ static __device__ void dequantize_q5_0(const void * vx, const int ib, const int
|
|||
v1 = x1*d;
|
||||
}
|
||||
|
||||
static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
||||
const block_q5_1 * x = (const block_q5_1 *) vx;
|
||||
|
||||
const float d = x[ib].d;
|
||||
const float m = x[ib].m;
|
||||
|
||||
uint32_t qh;
|
||||
memcpy(&qh, x[ib].qh, sizeof(qh));
|
||||
|
||||
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
|
||||
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
|
||||
|
||||
const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
|
||||
const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1);
|
||||
|
||||
v0 = x0*d + m;
|
||||
v1 = x1*d + m;
|
||||
}
|
||||
|
||||
static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
||||
const block_q8_0 * x = (const block_q8_0 *) vx;
|
||||
|
||||
const float d = x[ib].d;
|
||||
|
||||
const int8_t vi0 = x[ib].qs[iqs + 0];
|
||||
const int8_t vi1 = x[ib].qs[iqs + 1];
|
||||
|
||||
v0 = vi0*d;
|
||||
v1 = vi1*d;
|
||||
}
|
||||
|
||||
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
||||
const half * x = (const half *) vx;
|
||||
|
||||
v0 = __half2float(x[ib + 0]);
|
||||
v1 = __half2float(x[ib + 1]);
|
||||
}
|
||||
|
||||
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
|
||||
static const int qk = QK4_0;
|
||||
|
||||
|
@ -224,18 +270,20 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
|
|||
}
|
||||
}
|
||||
|
||||
template <int block_size, int qk, dequantize_kernel_t dequantize_kernel>
|
||||
template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
|
||||
const int row = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||
|
||||
__shared__ float tmp[block_size]; // separate sum for each thread
|
||||
tmp[tid] = 0;
|
||||
|
||||
for (int i = 0; i < ncols/block_size; i += 2) {
|
||||
const int col = i*block_size + 2*tid;
|
||||
const int ib = (row*ncols + col)/qk; // block index
|
||||
const int iqs = (col%qk)/2; // quant index
|
||||
const int iqs = (col%qk)/qr; // quant index
|
||||
const int iybs = col - col%qk; // y block start index
|
||||
|
||||
// dequantize
|
||||
|
@ -244,7 +292,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
|
|||
|
||||
// matrix multiplication
|
||||
tmp[tid] += v0 * y[iybs + iqs + 0];
|
||||
tmp[tid] += v1 * y[iybs + iqs + qk/2];
|
||||
tmp[tid] += v1 * y[iybs + iqs + y_offset];
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
|
@ -287,17 +335,32 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
|
|||
|
||||
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, dequantize_q4_0><<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
|
||||
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, dequantize_q4_1><<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
|
||||
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, dequantize_q5_0><<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
|
||||
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
|
||||
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
|
||||
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
// TODO: optimize
|
||||
|
@ -313,6 +376,12 @@ static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStre
|
|||
convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
|
||||
}
|
||||
|
||||
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16>
|
||||
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
|
@ -340,6 +409,12 @@ static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_t
|
|||
return dequantize_mul_mat_vec_q4_1_cuda;
|
||||
case GGML_TYPE_Q5_0:
|
||||
return dequantize_mul_mat_vec_q5_0_cuda;
|
||||
case GGML_TYPE_Q5_1:
|
||||
return dequantize_mul_mat_vec_q5_1_cuda;
|
||||
case GGML_TYPE_Q8_0:
|
||||
return dequantize_mul_mat_vec_q8_0_cuda;
|
||||
case GGML_TYPE_F16:
|
||||
return dequantize_mul_mat_vec_q8_0_cuda;
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue