From 1213af76ceae9e839e1da440f95604c0a013d68d Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Wed, 31 May 2023 22:28:59 +0300
Subject: [PATCH] mtl : add rope kernel

---
 examples/mtl/mtl.m     | 75 +++++++++++++++++++++++++++++++++++++++++-
 examples/mtl/mtl.metal | 55 +++++++++++++++++++++++++++++++
 llama.cpp              | 24 +++++++++-----
 3 files changed, 145 insertions(+), 9 deletions(-)

diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m
index e447dfcf6..a114841dd 100644
--- a/examples/mtl/mtl.m
+++ b/examples/mtl/mtl.m
@@ -41,6 +41,9 @@ struct ggml_mtl_context {

     id<MTLFunction>             function_mul_mat_q4_0;
     id<MTLComputePipelineState> pipeline_mul_mat_q4_0;
+
+    id<MTLFunction>             function_rope;
+    id<MTLComputePipelineState> pipeline_rope;
 };

 // MSL code
@@ -148,6 +151,10 @@ struct ggml_mtl_context * llama_mtl_init(
         ctx->function_mul_mat_q4_0 = [ctx->library newFunctionWithName:@"kernel_mul_mat_q4_0"];
         ctx->pipeline_mul_mat_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_q4_0 error:nil];
         fprintf(stderr, "%s: loaded kernel_mul_mat_q4_0: %p\n", __func__, (void *) ctx->pipeline_mul_mat_q4_0);
+
+        ctx->function_rope = [ctx->library newFunctionWithName:@"kernel_rope"];
+        ctx->pipeline_rope = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope error:nil];
+        fprintf(stderr, "%s: loaded kernel_rope: %p\n", __func__, (void *) ctx->pipeline_rope);
     }

     // MTLBuffer approach
@@ -250,6 +257,10 @@ int llama_mtl_eval(
         fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

         switch (gf->nodes[i]->op) {
+            case GGML_OP_RESHAPE:
+                {
+                    // noop
+                } break;
             case GGML_OP_ADD:
                 {
                     if (encoder == nil) {
@@ -453,6 +464,68 @@ int llama_mtl_eval(

                     [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case GGML_OP_ROPE:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    const int64_t ne00 = gf->nodes[i]->src0->ne[0];
+                    const int64_t ne01 = gf->nodes[i]->src0->ne[1];
+                    const int64_t ne02 = gf->nodes[i]->src0->ne[2];
+                    const int64_t ne03 = gf->nodes[i]->src0->ne[3];
+
+                    const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
+                    const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
+                    const uint64_t nb02 = gf->nodes[i]->src0->nb[2];
+                    const uint64_t nb03 = gf->nodes[i]->src0->nb[3];
+
+                    const int64_t ne0 = gf->nodes[i]->ne[0];
+                    const int64_t ne1 = gf->nodes[i]->ne[1];
+                    const int64_t ne2 = gf->nodes[i]->ne[2];
+                    const int64_t ne3 = gf->nodes[i]->ne[3];
+
+                    const uint64_t nb0 = gf->nodes[i]->nb[0];
+                    const uint64_t nb1 = gf->nodes[i]->nb[1];
+                    const uint64_t nb2 = gf->nodes[i]->nb[2];
+                    const uint64_t nb3 = gf->nodes[i]->nb[3];
+
+                    const int n_past = ((int32_t *) gf->nodes[i]->src1->data)[0]; // TODO: TMP !!!!!
+                    const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
+                    const int mode   = ((int32_t *) gf->nodes[i]->src1->data)[2];
+
+                    printf("rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
+                    printf("rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
+                    printf("rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
+
+                    [encoder setComputePipelineState:ctx->pipeline_rope];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                    [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+                    [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+                    [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+                    [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+                    [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+                    [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+                    [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+                    [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+                    [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
+                    [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
+                    [encoder setBytes:&mode   length:sizeof( int) atIndex:20];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
             default:
                 fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
                 GGML_ASSERT(false);
@@ -486,7 +559,7 @@ int llama_mtl_eval(
     {
         const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];

-        fprintf(stderr, "%s: time elapsed = %f\n", __func__, time_elapsed);
+        fprintf(stderr, "%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
     }

     // TODO

diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal
index 0cd93df7f..a46d016fb 100644
--- a/examples/mtl/mtl.metal
+++ b/examples/mtl/mtl.metal
@@ -210,3 +210,58 @@ kernel void kernel_mul_mat_q4_0(
         dst[r1*ne0 + r0] = sum[0];
     }
 }
+
+kernel void kernel_rope(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        constant       int & n_past,
+        constant       int & n_dims,
+        constant       int & mode,
+        uint3 tpig[[thread_position_in_grid]]) {
+    const int64_t i3 = tpig[2];
+    const int64_t i2 = tpig[1];
+    const int64_t i1 = tpig[0];
+
+    const bool is_neox = mode & 2;
+    const float theta_scale = pow(10000.0, -2.0f/n_dims);
+
+    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
+
+    float theta = (float)p;
+
+    if (!is_neox) {
+        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+            const float cos_theta = cos(theta);
+            const float sin_theta = sin(theta);
+
+            theta *= theta_scale;
+
+            device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        }
+    } else {
+        // TODO: implement
+    }
+}

diff --git a/llama.cpp b/llama.cpp
index caf74bfd1..88cfe26ec 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1270,19 +1270,20 @@ static bool llama_eval_internal(

         // self-attention
         {
-            auto * x = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-            // TODO: TMP !!!!
-            if (il == 0) {
-                ggml_set_name(x, "mtl-check");
-            }
+            //auto * x = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+            //struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);

             // compute Q and K and RoPE them
-            //struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
-            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
             struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
             ggml_set_name(Qcur, "Qcur");
             ggml_set_name(Kcur, "Kcur");

+            // TODO: TMP !!!!
+            if (il == 0) {
+                ggml_set_name(Qcur, "mtl-check");
+            }
+
             // store key and value to memory
             {
                 // compute the transposed [N, n_embd] V matrix
@@ -1437,7 +1438,14 @@ static bool llama_eval_internal(
     //ggml_graph_compute (ctx0, &gf);

     // lets export a smaller graph to get things rolling -- baby steps first
-    ggml_build_forward_expand(&gf_export, ggml_get_tensor(ctx0, "mtl-check"));
+    {
+        struct ggml_tensor * t = ggml_get_tensor(ctx0, "mtl-check");
+        if (!t) {
+            fprintf(stderr, "%s: failed to find tensor 'mtl-check'\n", __func__);
+            exit(1);
+        }
+        ggml_build_forward_expand(&gf_export, t);
+    }

     // print
     {
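
Note: kernel_rope above dispatches one single-thread threadgroup per row (a grid of ne01 x ne02 x ne03) and implements only the non-neox branch; the neox path is still marked TODO. For sanity-checking the Metal output, the same rotation can be written as a small CPU reference. The sketch below is illustrative and not part of the patch: the function name rope_ref and the contiguous [ne0, ne1, ne2] float layout are assumptions.

    #include <math.h>
    #include <stdint.h>

    // CPU reference for the non-neox branch of kernel_rope (illustrative only).
    // Assumes src/dst are contiguous float tensors of shape [ne0, ne1, ne2],
    // with ne0 the rotated dimension and ne2 the position dimension.
    static void rope_ref(float * dst, const float * src,
                         int64_t ne0, int64_t ne1, int64_t ne2,
                         int n_past, int n_dims, int mode) {
        const float theta_scale = powf(10000.0f, -2.0f/n_dims);

        for (int64_t i2 = 0; i2 < ne2; ++i2) {
            // positions are offset by n_past unless mode bit 0 is set
            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

            for (int64_t i1 = 0; i1 < ne1; ++i1) {
                const int64_t row = (i2*ne1 + i1)*ne0;

                float theta = (float) p;

                for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                    const float cos_theta = cosf(theta);
                    const float sin_theta = sinf(theta);

                    theta *= theta_scale;

                    // rotate each (even, odd) pair by the current angle
                    const float x0 = src[row + i0];
                    const float x1 = src[row + i0 + 1];

                    dst[row + i0]     = x0*cos_theta - x1*sin_theta;
                    dst[row + i0 + 1] = x0*sin_theta + x1*cos_theta;
                }
            }
        }
    }

Comparing this against the buffer read back from the GPU with a small elementwise tolerance is enough to validate the non-neox path; CPU and GPU transcendentals round differently, so exact equality should not be expected.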