From 61c8259a8841d57dff375fbc733d94162ec473e6 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 24 Aug 2023 15:32:27 +0300
Subject: [PATCH] metal : add mul_mat_q8_0_f32 kernel

---
 ggml-metal.m     | 13 +++++++++-
 ggml-metal.metal | 67 +++++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 76 insertions(+), 4 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index c0996f2c0..358a51b74 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -74,6 +74,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
@@ -200,6 +201,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
@@ -802,6 +804,15 @@ void ggml_metal_graph_compute(
                                     nth1 = 8;
                                     [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
                                 } break;
+                            case GGML_TYPE_Q8_0:
+                                {
+                                    GGML_ASSERT(ne02 == 1);
+                                    GGML_ASSERT(ne12 == 1);
+
+                                    nth0 = 8;
+                                    nth1 = 8;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
+                                } break;
                             case GGML_TYPE_Q2_K:
                                 {
                                     GGML_ASSERT(ne02 == 1);
@@ -873,7 +884,7 @@ void ggml_metal_graph_compute(
                         [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
                         [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
 
-                        if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                        if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
                             src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
                             [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                         }
diff --git a/ggml-metal.metal b/ggml-metal.metal
index c66bf912d..83a9d86f8 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -363,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
     const int first_row = (r0 * nsg + sgitg) * nr;
     const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
-    device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+    device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
     float yl[16]; // src1 vector cache
     float sumf[nr]={0.f};
 
@@ -435,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32(
     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+kernel void kernel_mul_mat_q8_0_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01[[buffer(4)]],
+        constant int64_t & ne02[[buffer(5)]],
+        constant int64_t & ne10[[buffer(9)]],
+        constant int64_t & ne12[[buffer(11)]],
+        constant int64_t & ne0[[buffer(15)]],
+        constant int64_t & ne1[[buffer(16)]],
+        constant uint & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int nr  = N_DST;
+    const int nsg = N_SIMDGROUP;
+    const int nw  = N_SIMDWIDTH;
+
+    const int nb = ne00/QK8_0;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * nsg + sgitg) * nr;
+    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+    device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
+    device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[16];
+    float sumf[nr]={0.f};
+
+    const int ix = tiisg/2;
+    const int il = tiisg%2;
+
+    device const float * yb = y + ix * QK8_0 + 16*il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
+        for (int i = 0; i < 16; ++i) {
+            yl[i] = yb[i];
+        }
+
+        for (int row = 0; row < nr; row++) {
+            device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+            float sumq = 0.f;
+            for (int iq = 0; iq < 16; ++iq) {
+                sumq += qs[iq] * yl[iq];
+            }
+            sumf[row] += sumq*x[ib+row*nb].d;
+        }
+
+        yb += QK8_0 * 16;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+        }
+    }
+}
+
 kernel void kernel_mul_mat_f16_f32(
         device const char * src0,
         device const char * src1,
@@ -486,7 +548,6 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
-
 kernel void kernel_alibi_f32(
         device const float * src0,
         device float * dst,
@@ -1653,7 +1714,7 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
-    device const uint8_t * qs = ((device const uint8_t *)xb->qs);
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
     const half d = xb->d;
 
     for (int i=0;i<16;i++) {