From 1202e06c6f268660e85846eb51912122292707f4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 24 Aug 2023 15:42:29 +0300 Subject: [PATCH] metal : add Q8_0 mul_mm kernel --- ggml-metal.m | 5 ++++- ggml-metal.metal | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/ggml-metal.m b/ggml-metal.m index 358a51b74..06eb3872e 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -83,6 +83,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32); GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32); @@ -209,6 +210,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); @@ -751,9 +753,10 @@ void ggml_metal_graph_compute( ne00%32 == 0 && ne11 > 1) { switch (src0->type) { - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; + case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break; case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 83a9d86f8..82e1a0c7a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2041,6 +2041,7 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm;