From 2a24994badb709c5833c1126974af4d3677a4f06 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 30 May 2023 22:02:54 +0300
Subject: [PATCH] mtl : initial mul_mat Q4 kernel (wrong results)

---
 examples/mtl/mtl.m     |  48 ++++++++++++++++++-
 examples/mtl/mtl.metal | 102 +++++++++++++++++++++++++++++++++++------
 llama.cpp              |  13 ++++--
 3 files changed, 144 insertions(+), 19 deletions(-)

diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m
index f13c00776..bd424c23d 100644
--- a/examples/mtl/mtl.m
+++ b/examples/mtl/mtl.m
@@ -38,6 +38,9 @@ struct ggml_mtl_context {
 
     id<MTLFunction>             function_rms_norm;
     id<MTLComputePipelineState> pipeline_rms_norm;
+
+    id<MTLFunction>             function_mul_mat_q4_0;
+    id<MTLComputePipelineState> pipeline_mul_mat_q4_0;
 };
 
 // MSL code
@@ -141,6 +144,10 @@ struct ggml_mtl_context * llama_mtl_init(
         ctx->function_rms_norm = [ctx->library newFunctionWithName:@"kernel_rms_norm"];
         ctx->pipeline_rms_norm = [ctx->device newComputePipelineStateWithFunction:ctx->function_rms_norm error:nil];
         fprintf(stderr, "%s: loaded kernel_rms_norm: %p\n", __func__, (void *) ctx->pipeline_rms_norm);
+
+        ctx->function_mul_mat_q4_0 = [ctx->library newFunctionWithName:@"kernel_mul_mat_q4_0"];
+        ctx->pipeline_mul_mat_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_q4_0 error:nil];
+        fprintf(stderr, "%s: loaded kernel_mul_mat_q4_0: %p\n", __func__, (void *) ctx->pipeline_mul_mat_q4_0);
     }
 
     // MTLBuffer approach
@@ -317,7 +324,9 @@ int llama_mtl_eval(
                     [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
             case GGML_OP_MUL_MAT:
-                {
+                if (gf->nodes[i]->src0->type == GGML_TYPE_F32) {
+                    // for F32 x F32 we use MPS
+
                     if (encoder != nil) {
                         [encoder endEncoding];
                         encoder = nil;
@@ -354,6 +363,43 @@ int llama_mtl_eval(
                         transposeLeft:false transposeRight:true resultRows:nrows1 resultColumns:nrows0 interiorColumns:ncols0 alpha:1.0 beta:0.0];
 
                     [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
+                } else {
+                    // for Q4 x F32 we use custom kernel
+
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    GGML_ASSERT(gf->nodes[i]->src0->ne[2] == 1);
+                    GGML_ASSERT(gf->nodes[i]->src1->ne[2] == 1);
+
+                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
+                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    const int64_t ncols0 = gf->nodes[i]->src0->ne[0];
+                    const int64_t nrows0 = gf->nodes[i]->src0->ne[1];
+
+                    const int64_t ncols1 = gf->nodes[i]->src1->ne[0];
+                    const int64_t nrows1 = gf->nodes[i]->src1->ne[1];
+
+                    const int64_t ncols = gf->nodes[i]->ne[0];
+                    const int64_t nrows = gf->nodes[i]->ne[1];
+
+                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                    [encoder setBytes:&ncols0 length:sizeof(ncols0) atIndex:3];
+                    [encoder setBytes:&nrows0 length:sizeof(nrows0) atIndex:4];
+                    [encoder setBytes:&ncols1 length:sizeof(ncols1) atIndex:5];
+                    [encoder setBytes:&nrows1 length:sizeof(nrows1) atIndex:6];
+                    [encoder setBytes:&ncols  length:sizeof(ncols)  atIndex:7];
+                    [encoder setBytes:&nrows  length:sizeof(nrows)  atIndex:8];
+
+                    printf("mul_mat: %lldx%lld * %lldx%lld -> %lldx%lld\n", ncols0, nrows0, ncols1, nrows1, ncols, nrows);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(nrows0, nrows1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
                 } break;
            case GGML_OP_GET_ROWS:
                {
diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal
index 78dfbe011..f67d24f71 100644
--- a/examples/mtl/mtl.metal
+++ b/examples/mtl/mtl.metal
@@ -7,8 +7,8 @@ using namespace metal;
 #define QK4_0 32
 #define QR4_0 2
 typedef struct {
-    half d;                // delta
-    uint8_t qs[QK4_0 / 2]; // nibbles / quants
+    half    d;              // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
 } block_q4_0;
 
 static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
@@ -38,8 +38,8 @@ kernel void kernel_add(
         device const float * src0,
         device const float * src1,
         device float * dst,
-        uint gid[[thread_position_in_grid]]) {
-    dst[gid] = src0[gid] + src1[gid];
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] + src1[tpig];
 }
 
 // assumption: src1 is a row
@@ -49,15 +49,15 @@ kernel void kernel_mul(
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        uint gid[[thread_position_in_grid]]) {
-    dst[gid] = src0[gid] * src1[gid % ne00];
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src1[tpig % ne00];
 }
 
 kernel void kernel_relu(
         device const float * src0,
         device float * dst,
-        uint gid[[thread_position_in_grid]]) {
-    dst[gid] = max(0.0f, src0[gid]);
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = max(0.0f, src0[tpig]);
 }
 
 // TODO: broken
@@ -85,8 +85,8 @@ kernel void kernel_get_rows_q4_0(
         constant int64_t & ne00,
         constant uint64_t & nb01,
         constant uint64_t & nb1,
-        uint gid[[thread_position_in_grid]]) {
-    const int i = gid;
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
     dequantize_row_q4_0(
@@ -100,8 +100,8 @@ kernel void kernel_rms_norm(
         constant int64_t & ne00,
         constant uint64_t & nb01,
         constant float & eps,
-        uint gid[[thread_position_in_grid]]) {
-    device const float * x = (device const float *) ((device const char *) src0 + gid*nb01);
+        uint tpig[[thread_position_in_grid]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tpig*nb01);
 
     float sum = 0.0f;
     for (int i00 = 0; i00 < ne00; i00++) {
@@ -111,8 +111,84 @@ kernel void kernel_rms_norm(
     const float mean = sum/ne00;
     const float scale = 1.0f/sqrt(mean + eps);
 
-    device float * y = dst + gid*ne00;
+    device float * y = dst + tpig*ne00;
     for (int i00 = 0; i00 < ne00; i00++) {
         y[i00] = x[i00] * scale;
     }
 }
+
+kernel void kernel_mul_mat_q4_0(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpig[[thread_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2 tptg[[threads_per_threadgroup]]) {
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    const int qk = QK4_0;
+    const int nb = ne00/qk;
+
+    device const block_q4_0 * x = (device const block_q4_0 *) (src0) + r0*nb;
+    device const float      * y = (device const float      *) (src1) + r1*ne10;
+
+    threadgroup float sum[32]; // TODO: should be equal to threadgroup size
+    sum[tpitg.x] = 0.0f;
+
+    for (int i = 0; i < nb; i += tptg.x) {
+        device const uint4  * x0p = (device const uint4  *) (x + i);
+        device const float4 * y0p = (device const float4 *) (y + i*qk);
+
+        const uint4 x0 = *x0p;
+
+        const uint4 x0l = x0 & uint4(0x0F0F0F0F);
+        const uint4 x0h = x0 >> 4;
+
+        const int4 x0ls = as_type<int4>(x0l) - int4(8);
+        const int4 x0hs = as_type<int4>(x0h) - int4(8);
+
+        thread const uchar * x0lsb = (thread const uchar *) &x0ls;
+        thread const uchar * x0hsb = (thread const uchar *) &x0hs;
+
+        const float4 y00 = *(y0p + 0);
+        const float4 y01 = *(y0p + 1);
+        const float4 y02 = *(y0p + 2);
+        const float4 y03 = *(y0p + 3);
+        const float4 y04 = *(y0p + 4);
+        const float4 y05 = *(y0p + 5);
+        const float4 y06 = *(y0p + 6);
+        const float4 y07 = *(y0p + 7);
+
+        const float d = (x + i)->d;
+
+        sum[tpitg.x] += (
+                x0lsb[ 0]*y00[0] + x0lsb[ 1]*y00[1] + x0lsb[ 2]*y00[2] + x0lsb[ 3]*y00[3] +
+                x0lsb[ 4]*y01[0] + x0lsb[ 5]*y01[1] + x0lsb[ 6]*y01[2] + x0lsb[ 7]*y01[3] +
+                x0lsb[ 8]*y02[0] + x0lsb[ 9]*y02[1] + x0lsb[10]*y02[2] + x0lsb[11]*y02[3] +
+                x0lsb[12]*y03[0] + x0lsb[13]*y03[1] + x0lsb[14]*y03[2] + x0lsb[15]*y03[3] +
+                x0hsb[ 0]*y04[0] + x0hsb[ 1]*y04[1] + x0hsb[ 2]*y04[2] + x0hsb[ 3]*y04[3] +
+                x0hsb[ 4]*y05[0] + x0hsb[ 5]*y05[1] + x0hsb[ 6]*y05[2] + x0hsb[ 7]*y05[3] +
+                x0hsb[ 8]*y06[0] + x0hsb[ 9]*y06[1] + x0hsb[10]*y06[2] + x0hsb[11]*y06[3] +
+                x0hsb[12]*y07[0] + x0hsb[13]*y07[1] + x0hsb[14]*y07[2] + x0hsb[15]*y07[3]
+            ) * d;
+    }
+
+    // accumulate the sum from all threads in the threadgroup
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = tptg.x/2; i > 0; i /= 2) {
+        if (tpitg.x < i) {
+            sum[tpitg.x] += sum[tpitg.x + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    dst[r1*ne0 + r0] = sum[0];
+}
diff --git a/llama.cpp b/llama.cpp
index 3ddfeff01..caf74bfd1 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1266,16 +1266,19 @@ static bool llama_eval_internal(
             // cur = cur*attention_norm(broadcasted)
             cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
-
-            // TODO: TMP !!!!
-            if (il == 0) {
-                ggml_set_name(cur, "mtl-check");
-            }
         }

         // self-attention
         {
+            auto * x = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+
+            // TODO: TMP !!!!
+            if (il == 0) {
+                ggml_set_name(x, "mtl-check");
+            }
+
             // compute Q and K and RoPE them
-            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            //struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);
             struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
             ggml_set_name(Qcur, "Qcur");
             ggml_set_name(Kcur, "Kcur");
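
Note (reviewer aid, not part of the patch): since the subject line flags wrong results, a scalar CPU reference is useful for diffing the kernel's output. The sketch below computes the value that threadgroup (r0, r1) is expected to write to dst[r1*ne0 + r0], assuming the Q4_0 nibble layout used by dequantize_row_q4_0 above (within each 32-element block, the 16 low nibbles come first, then the 16 high nibbles, each biased by -8 and scaled by d). The names block_q4_0_ref and dot_q4_0_f32_ref are hypothetical, and d is widened from half to float for portability.

    #include <stdint.h>

    #define QK4_0 32

    typedef struct {
        float   d;             // delta (stored as half in the Metal kernel)
        uint8_t qs[QK4_0 / 2]; // nibbles / quants
    } block_q4_0_ref;

    // dot product of one Q4_0 row (nb blocks) with one F32 row (nb*QK4_0 floats)
    static float dot_q4_0_f32_ref(const block_q4_0_ref * x, const float * y, int nb) {
        float sum = 0.0f;
        for (int i = 0; i < nb; i++) {
            for (int j = 0; j < QK4_0/2; j++) {
                const int q0 = (x[i].qs[j] & 0x0F) - 8; // low nibble  -> element j
                const int q1 = (x[i].qs[j] >>   4) - 8; // high nibble -> element j + QK4_0/2
                sum += x[i].d * (q0*y[i*QK4_0 + j] + q1*y[i*QK4_0 + j + QK4_0/2]);
            }
        }
        return sum;
    }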
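When comparing against such a reference, two properties of the kernel matter: the final tree reduction (for (uint i = tptg.x/2; i > 0; i /= 2)) only accumulates all partial sums when tptg.x is a power of two, and the hard-coded threadgroup float sum[32] ties the kernel to the 32-wide threadgroup dispatched from mtl.m, as the TODO in the kernel itself points out.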