From 46a0881c7fcf84b43f22f1219bc529445f58521d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 24 Aug 2023 13:40:34 +0300 Subject: [PATCH] metal : add dequantize_q8_0 kernel --- ggml-metal.m | 5 ++++- ggml-metal.metal | 27 ++++++++++++++++++++++----- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 969cf7daa..c0996f2c0 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -63,6 +63,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_1); + GGML_METAL_DECL_KERNEL(get_rows_q8_0); GGML_METAL_DECL_KERNEL(get_rows_q2_K); GGML_METAL_DECL_KERNEL(get_rows_q3_K); GGML_METAL_DECL_KERNEL(get_rows_q4_K); @@ -188,6 +189,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_1); + GGML_METAL_ADD_KERNEL(get_rows_q8_0); GGML_METAL_ADD_KERNEL(get_rows_q2_K); GGML_METAL_ADD_KERNEL(get_rows_q3_K); GGML_METAL_ADD_KERNEL(get_rows_q4_K); @@ -896,9 +898,10 @@ void ggml_metal_graph_compute( case GGML_OP_GET_ROWS: { switch (src0->type) { - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; + case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break; case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 7bc3fdf37..c66bf912d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -18,6 +18,12 @@ typedef struct { uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; +#define QK8_0 32 +typedef struct { + half d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; + kernel void kernel_add( device const float * src0, device const float * src1, @@ -1621,12 +1627,12 @@ template void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 1); const half d = il ? (xb->d / 16.h) : xb->d; - const half m = il ? (-8.h * 16.h) : -8.h; + const half m = il ? ( -8.h * 16.h) : -8.h; const ushort mask0 = il ? 0x00F0 : 0x000F; const ushort mask1 = il ? 0xF000 : 0x0F00; for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)] = (((qs[i] & mask0)) + m) * d; + reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d; reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d; } } @@ -1640,11 +1646,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg const ushort mask1 = il ? 0xF000 : 0x0F00; for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)] = (((qs[i] & mask0)) * d) + m; + reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m; reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m; } } +template +void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { + device const uint8_t * qs = ((device const uint8_t *)xb->qs); + const half d = xb->d; + + for (int i=0;i<16;i++) { + reg[i/4][i%4] = (qs[i + 16*il] * d); + } +} + template void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { const half d = xb->d; @@ -1947,9 +1963,10 @@ kernel void kernel_mul_mm(device const uchar * src0, typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \ constant uint64_t &, constant uint64_t &, uint, uint, uint); -template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; @@ -1960,7 +1977,7 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \ constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint); -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm;