ggml : add Q8_0 quantization format (rename the old one to Q8_1) (ARM NEON) (#1179)
* ggml : add Q8_0 quantization format (rename the old one to Q8_1) * tests : fix test-quantize-fns * ggml : finalize Q8_0 implementation * ggml : use q4_0_q8_0 and q4_2_q8_0 * ggml : fix Q8_0 dot product bug (ARM) * ggml : Q8_0 unroll x2 * ggml : fix bug - using wrong block type * ggml : extend quantize_fns_t with "vec_dot_type" * ggml : fix Q8_0 to use 255 values out of 256 * ggml : fix assert using wrong QK4_2 instead of QK4_3
This commit is contained in:
parent
dd0eabc049
commit
7a32fcb3b2
8 changed files with 312 additions and 147 deletions
28
ggml-cuda.cu
28
ggml-cuda.cu
|
@ -37,6 +37,13 @@ typedef struct {
|
|||
} block_q4_3;
|
||||
static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
|
||||
|
||||
#define QK8_0 32
|
||||
typedef struct {
|
||||
float d; // delta
|
||||
int8_t qs[QK8_0]; // quants
|
||||
} block_q8_0;
|
||||
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
|
||||
|
||||
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
|
||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||
|
||||
|
@ -131,6 +138,22 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) {
|
|||
}
|
||||
}
|
||||
|
||||
static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
|
||||
const block_q8_0 * x = (const block_q8_0 *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
|
||||
const float d = x[i].d;
|
||||
|
||||
const int8_t * pp = x[i].qs;
|
||||
|
||||
for (int l = 0; l < QK8_0; l++) {
|
||||
const int8_t vi = pp[l];
|
||||
|
||||
y[i*QK8_0 + l] = vi*d;
|
||||
}
|
||||
}
|
||||
|
||||
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_0;
|
||||
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||
|
@ -151,6 +174,11 @@ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t st
|
|||
dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK8_0;
|
||||
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
// buffer pool for cuda
|
||||
#define MAX_CUDA_BUFFERS 16
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue