From f12863ed38f770d25b7c3ba0c8b35a4d680c267c Mon Sep 17 00:00:00 2001 From: Cebtenzzre Date: Fri, 29 Sep 2023 18:20:19 -0400 Subject: [PATCH] metal : fix hardcoded constants in mul_vec_q_n_f32 --- ggml-metal.metal | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 5e1af6a09..ea04162b0 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -435,31 +435,31 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device const uint offset0 = first_row * nb + im/gqa*(nb*ne0); device const block_q_type * x = (device const block_q_type *) src0 + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[16]; // src1 vector cache + float yl[QK4_0/2]; // src1 vector cache float sumf[nr]={0.f}; const int ix = tiisg/2; - const int il = 8*(tiisg%2); + const int il = QK4_0/4*(tiisg%2); device const float * yb = y + ix * QK4_0 + il; // each thread in a SIMD group deals with half a block. for (int ib = ix; ib < nb; ib += nw/2) { float sumy = 0; - for (int i = 0; i < 8; i += 2) { + for (int i = 0; i < QK4_0/4; i += 2) { sumy += yb[i] + yb[i+1]; yl[i+0] = yb[i+ 0]; yl[i+1] = yb[i+ 1]/256.f; - sumy += yb[i+16] + yb[i+17]; - yl[i+8] = yb[i+16]/16.f; - yl[i+9] = yb[i+17]/4096.f; + sumy += yb[i+QK4_0/2] + yb[i+QK4_0/2+1]; + yl[i+QK4_0/4] = yb[i+QK4_0/2] /16.f; + yl[i+QK4_0/4+1] = yb[i+QK4_0/2+1]/4096.f; } for (int row = 0; row < nr; row++) { sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); } - yb += QK4_0 * 16; + yb += QK4_0 * nw/2; } for (int row = 0; row < nr; ++row) {