From 29bec00ba06eb25fcd5948ca63f200780816ff1e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 May 2023 22:31:07 +0300 Subject: [PATCH] mtl : another mul_mat Q4 (still does not work) --- examples/mtl/mtl.metal | 75 ++++++++++++++++++++++++++---------------- 1 file changed, 47 insertions(+), 28 deletions(-) diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal index 348a432ab..7eb9259e2 100644 --- a/examples/mtl/mtl.metal +++ b/examples/mtl/mtl.metal @@ -137,45 +137,64 @@ kernel void kernel_mul_mat_q4_0( const int qk = QK4_0; const int nb = ne00/qk; - device const block_q4_0 * x = (device const block_q4_0 *) (src0) + r0*nb; - device const float * y = (device const float *) (src1) + r1*ne10; + device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb; + device const float * y = (device const float *) src1 + r1*ne10; threadgroup float sum[32]; // TODO: should be equal to threadgroup size sum[tpitg.x] = 0.0f; for (int i = 0; i < nb; i += tptg.x) { - device const uint4 * x0p = (device const uint4 *) (x + i)->qs; - device const float4 * y0p = (device const float4 *) (y + i*qk); + //device const uint4 * x0p = (device const uint4 *) (x + i)->qs; + //device const float4 * y0p = (device const float4 *) (y + i*qk); - const uint4 x0 = *x0p; + //const uint4 x0 = *x0p; - const uint4 x0l = (x0 & uint4(0x0F0F0F0F)); - const uint4 x0h = (x0 & uint4(0xF0F0F0F0)) >> 4; + //const uint4 x0l = (x0 & uint4(0x0F0F0F0F)); + //const uint4 x0h = (x0 & uint4(0xF0F0F0F0)) >> 4; - thread const char * x0lsb = (thread const char *) &x0l; - thread const char * x0hsb = (thread const char *) &x0h; + //thread const char * x0lsb = (thread const char *) &x0l; + //thread const char * x0hsb = (thread const char *) &x0h; - const float4 y00 = *(y0p + 0); - const float4 y01 = *(y0p + 1); - const float4 y02 = *(y0p + 2); - const float4 y03 = *(y0p + 3); - const float4 y04 = *(y0p + 4); - const float4 y05 = *(y0p + 5); - const float4 y06 = *(y0p + 6); - const float4 y07 = *(y0p + 7); + //const float4 y00 = *(y0p + 0); + //const float4 y01 = *(y0p + 1); + //const float4 y02 = *(y0p + 2); + //const float4 y03 = *(y0p + 3); + //const float4 y04 = *(y0p + 4); + //const float4 y05 = *(y0p + 5); + //const float4 y06 = *(y0p + 6); + //const float4 y07 = *(y0p + 7); - const half d = (x + i)->d; + //const half d = (x + i)->d; - sum[tpitg.x] += ( - (x0lsb[ 0] - 8)*y00[0] + (x0lsb[ 1] - 8)*y00[1] + (x0lsb[ 2] - 8)*y00[2] + (x0lsb[ 3] - 8)*y00[3] + - (x0lsb[ 4] - 8)*y01[0] + (x0lsb[ 5] - 8)*y01[1] + (x0lsb[ 6] - 8)*y01[2] + (x0lsb[ 7] - 8)*y01[3] + - (x0lsb[ 8] - 8)*y02[0] + (x0lsb[ 9] - 8)*y02[1] + (x0lsb[10] - 8)*y02[2] + (x0lsb[11] - 8)*y02[3] + - (x0lsb[12] - 8)*y03[0] + (x0lsb[13] - 8)*y03[1] + (x0lsb[14] - 8)*y03[2] + (x0lsb[15] - 8)*y03[3] + - (x0hsb[ 0] - 8)*y04[0] + (x0hsb[ 1] - 8)*y04[1] + (x0hsb[ 2] - 8)*y04[2] + (x0hsb[ 3] - 8)*y04[3] + - (x0hsb[ 4] - 8)*y05[0] + (x0hsb[ 5] - 8)*y05[1] + (x0hsb[ 6] - 8)*y05[2] + (x0hsb[ 7] - 8)*y05[3] + - (x0hsb[ 8] - 8)*y06[0] + (x0hsb[ 9] - 8)*y06[1] + (x0hsb[10] - 8)*y06[2] + (x0hsb[11] - 8)*y06[3] + - (x0hsb[12] - 8)*y07[0] + (x0hsb[13] - 8)*y07[1] + (x0hsb[14] - 8)*y07[2] + (x0hsb[15] - 8)*y07[3] - ) * d; + //sum[tpitg.x] += ( + // (x0lsb[ 0] - 8)*y00[0] + (x0lsb[ 1] - 8)*y00[1] + (x0lsb[ 2] - 8)*y00[2] + (x0lsb[ 3] - 8)*y00[3] + + // (x0lsb[ 4] - 8)*y01[0] + (x0lsb[ 5] - 8)*y01[1] + (x0lsb[ 6] - 8)*y01[2] + (x0lsb[ 7] - 8)*y01[3] + + // (x0lsb[ 8] - 8)*y02[0] + (x0lsb[ 9] - 8)*y02[1] + (x0lsb[10] - 8)*y02[2] + (x0lsb[11] - 8)*y02[3] + + // (x0lsb[12] - 8)*y03[0] + (x0lsb[13] - 8)*y03[1] + (x0lsb[14] - 8)*y03[2] + (x0lsb[15] - 8)*y03[3] + + // (x0hsb[ 0] - 8)*y04[0] + (x0hsb[ 1] - 8)*y04[1] + (x0hsb[ 2] - 8)*y04[2] + (x0hsb[ 3] - 8)*y04[3] + + // (x0hsb[ 4] - 8)*y05[0] + (x0hsb[ 5] - 8)*y05[1] + (x0hsb[ 6] - 8)*y05[2] + (x0hsb[ 7] - 8)*y05[3] + + // (x0hsb[ 8] - 8)*y06[0] + (x0hsb[ 9] - 8)*y06[1] + (x0hsb[10] - 8)*y06[2] + (x0hsb[11] - 8)*y06[3] + + // (x0hsb[12] - 8)*y07[0] + (x0hsb[13] - 8)*y07[1] + (x0hsb[14] - 8)*y07[2] + (x0hsb[15] - 8)*y07[3] + // ) * d; + + device const uchar * x0p = (device const uchar *) (x + i)->qs; + device const float * y0p = (device const float *) (y + i*qk); + + float acc = 0.0f; + + for (int j = 0; j < 16; ++j) { + const uchar x0v = *(x0p + j); + + const int x0 = x0v & 0x0F; + const int x1 = x0v >> 4; + + const float y0 = *(y0p + j); + const float y1 = *(y0p + j + 16); + + acc += (x0 - 8)*y0 + (x1 - 8)*y1; + } + + sum[tpitg.x] += acc * (x + i)->d; } // accumulate the sum from all threads in the threadgroup