From 948fcfde7e74dc770687da9f0ea738195b782ac4 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 1 Jun 2023 19:21:28 +0300
Subject: [PATCH] mtl : add cpy kernel + handle view ops

---
 examples/mtl/mtl.m     | 126 ++++++++++++++++++++++++++++++++++++-----
 examples/mtl/mtl.metal |  42 ++++++++++++++
 llama.cpp              |  44 +++++++++++---
 3 files changed, 191 insertions(+), 21 deletions(-)

diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m
index 7e48e2b95..6d509a2ab 100644
--- a/examples/mtl/mtl.m
+++ b/examples/mtl/mtl.m
@@ -44,6 +44,9 @@ struct ggml_mtl_context {
 
     id<MTLFunction>             function_rope;
     id<MTLComputePipelineState> pipeline_rope;
+
+    id<MTLFunction>             function_cpy_f32_f16;
+    id<MTLComputePipelineState> pipeline_cpy_f32_f16;
 };
 
 // MSL code
@@ -155,6 +158,10 @@ struct ggml_mtl_context * llama_mtl_init(
         ctx->function_rope = [ctx->library newFunctionWithName:@"kernel_rope"];
         ctx->pipeline_rope = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope error:nil];
         fprintf(stderr, "%s: loaded kernel_rope: %p\n", __func__, (void *) ctx->pipeline_rope);
+
+        ctx->function_cpy_f32_f16 = [ctx->library newFunctionWithName:@"kernel_cpy_f32_f16"];
+        ctx->pipeline_cpy_f32_f16 = [ctx->device newComputePipelineStateWithFunction:ctx->function_cpy_f32_f16 error:nil];
+        fprintf(stderr, "%s: loaded kernel_cpy_f32_f16: %p\n", __func__, (void *) ctx->pipeline_cpy_f32_f16);
     }
 
     // MTLBuffer approach
@@ -258,6 +265,7 @@ int llama_mtl_eval(
 
         switch (gf->nodes[i]->op) {
             case GGML_OP_RESHAPE:
+            case GGML_OP_VIEW:
             case GGML_OP_TRANSPOSE:
                 {
                     // noop
@@ -527,6 +535,76 @@ int llama_mtl_eval(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case GGML_OP_CPY:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    const int64_t ne00 = gf->nodes[i]->src0->ne[0];
+                    const int64_t ne01 = gf->nodes[i]->src0->ne[1];
+                    const int64_t ne02 = gf->nodes[i]->src0->ne[2];
+                    const int64_t ne03 = gf->nodes[i]->src0->ne[3];
+
+                    const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
+                    const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
+                    const uint64_t nb02 = gf->nodes[i]->src0->nb[2];
+                    const uint64_t nb03 = gf->nodes[i]->src0->nb[3];
+
+                    const int64_t ne0 = gf->nodes[i]->ne[0];
+                    const int64_t ne1 = gf->nodes[i]->ne[1];
+                    const int64_t ne2 = gf->nodes[i]->ne[2];
+                    const int64_t ne3 = gf->nodes[i]->ne[3];
+
+                    const uint64_t nb0 = gf->nodes[i]->nb[0];
+                    const uint64_t nb1 = gf->nodes[i]->nb[1];
+                    const uint64_t nb2 = gf->nodes[i]->nb[2];
+                    const uint64_t nb3 = gf->nodes[i]->nb[3];
+
+                    const enum ggml_type src0t = gf->nodes[i]->src0->type;
+                    const enum ggml_type dstt  = gf->nodes[i]->type;
+
+                    printf("cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
+                    printf("cpy: %llu x %llu x %llu x %llu\n", nb00, nb01, nb02, nb03);
+                    printf("cpy: %lld x %lld x %lld x %lld\n", ne0,  ne1,  ne2,  ne3);
+                    printf("cpy: %llu x %llu x %llu x %llu\n", nb0,  nb1,  nb2,  nb3);
+                    printf("cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));
+
+                    switch (src0t) {
+                        case GGML_TYPE_F32:
+                            {
+                                switch (dstt) {
+                                    case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+                                    default: GGML_ASSERT(false && "not implemented");
+                                };
+                            } break;
+                        default: GGML_ASSERT(false && "not implemented");
+                    }
+
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                    [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+                    [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+                    [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+                    [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+                    [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+                    [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+                    [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+                    [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+                } break;
             default:
                 fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
                 GGML_ASSERT(false);
@@ -568,21 +646,41 @@ int llama_mtl_eval(
     {
         struct ggml_tensor * t = ggml_get_tensor(ctx->ctx_eval, "mtl-check");
 
-        float * data = (float *) ctx->out.contents;
-        printf("data: ");
-        int n = t->ne[0];
-        if (n > 10) {
-            n = 10;
+        if (t->type == GGML_TYPE_F32) {
+            const float * data = (const float *) ctx->out.contents;
+            printf("data: ");
+            int n = ggml_nelements(t);
+            if (n > 10) {
+                n = 10;
+            }
+            for (int i = 0; i < n; i++) {
+                printf("%f ", data[i]);
+            }
+            printf("\n");
+            double sum = 0.0;
+            for (int i = 0; i < ggml_nelements(t); i++) {
+                sum += data[i];
+            }
+            printf("sum: %f\n", sum);
+        } else if (t->type == GGML_TYPE_F16) {
+            const ggml_fp16_t * data = (const ggml_fp16_t *) ctx->out.contents;
+            printf("data: ");
+            int n = ggml_nelements(t);
+            if (n > 10) {
+                n = 10;
+            }
+            for (int i = 0; i < n; i++) {
+                printf("%f ", ggml_fp16_to_fp32(data[i]));
+            }
+            printf("\n");
+            double sum = 0.0;
+            for (int i = 0; i < ggml_nelements(t); i++) {
+                sum += ggml_fp16_to_fp32(data[i]);
+            }
+            printf("sum: %f\n", sum);
+        } else {
+            GGML_ASSERT(false && "not implemented");
         }
-        for (int i = 0; i < n; i++) {
-            printf("%f ", data[i]);
-        }
-        printf("\n");
-        double sum = 0.0;
-        for (int i = 0; i < ggml_nelements(t); i++) {
-            sum += data[i];
-        }
-        printf("sum: %f\n", sum);
     }
 
     return 0;
diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal
index a46d016fb..7e5c3aad4 100644
--- a/examples/mtl/mtl.metal
+++ b/examples/mtl/mtl.metal
@@ -265,3 +265,45 @@ kernel void kernel_rope(
         // TODO: implement
     }
 }
+
+kernel void kernel_cpy_f32_f16(
+        device const float * src0,
+        device       half * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        dst_data[i00] = src[0];
+    }
+}
diff --git a/llama.cpp b/llama.cpp
index fdbbca69f..5e7c3db86 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1283,18 +1283,21 @@ static bool llama_eval_internal(
         {
             // compute the transposed [N, n_embd] V matrix
             struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
 
-            // TODO: TMP !!!!
-            if (il == 0) {
-                ggml_set_name(Vcur, "mtl-check");
-            }
 
             struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
             struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
                     (   n_ctx)*ggml_element_size(kv_self.v),
                     (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
 
+            struct ggml_tensor * t = ggml_cpy(ctx0, Kcur, k);
+            // TODO: TMP !!!!
+            if (il == 0) {
+                ggml_set_name(t, "mtl-check");
+            }
+
             // important: storing RoPE-ed version of K in the KV cache!
-            ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
+            //ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
+            ggml_build_forward_expand(&gf, t);
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
         }
 
@@ -1448,7 +1451,7 @@ static bool llama_eval_internal(
 
     // print
     {
-        auto print_t = [&](struct ggml_tensor * t) {
+        auto print_t_f32 = [&](struct ggml_tensor * t) {
            float * data = (float *)t->data;
            printf("data: ");
            for (int i = 0; i < std::min((int) t->ne[0], 10); i++) {
@@ -1461,9 +1464,36 @@ static bool llama_eval_internal(
            }
            printf("sum: %f\n", sum);
        };
+        auto print_t_f16 = [&](struct ggml_tensor * t) {
+           ggml_fp16_t * data = (ggml_fp16_t *)t->data;
+           printf("data: ");
+           for (int i = 0; i < std::min((int) t->ne[0], 10); i++) {
+               printf("%f ", ggml_fp16_to_fp32(data[i]));
+           }
+           printf("\n");
+           double sum = 0.0;
+           for (int i = 0; i < ggml_nelements(t); i++) {
+               sum += ggml_fp16_to_fp32(data[i]);
+           }
+           printf("sum: %f\n", sum);
+        };
 
         ggml_graph_compute(ctx0, &gf_export);
-        print_t(ggml_get_tensor(ctx0, "mtl-check"));
+
+        {
+            auto * t = ggml_get_tensor(ctx0, "mtl-check");
+            switch (t->type) {
+                case GGML_TYPE_F32:
+                    print_t_f32(t);
+                    break;
+                case GGML_TYPE_F16:
+                    print_t_f16(t);
+                    break;
+                default:
+                    fprintf(stderr, "%s: unsupported type\n", __func__);
+                    exit(1);
+            }
+        }
     }
 
     if (cgraph_fname) {
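
Editor's note (illustrative sketch, not part of the patch): in kernel_cpy_f32_f16 each threadgroup handles one source row (i01, i02, i03); the row's flattened element index n is decomposed against the destination shape ne0..ne3 so the copy also lands correctly when dst is a non-contiguous view, and the 32 threads of the group then stride over the ne00 elements of the row. A CPU reference of the same index math is sketched below, assuming the usual ggml ne/nb (shape / stride-in-bytes) conventions; ggml_fp16_t and ggml_fp32_to_fp16 are the real helpers from ggml.h, while cpy_f32_f16_ref is a hypothetical name used only for this sketch.

    #include <stdint.h>
    #include "ggml.h" // ggml_fp16_t, ggml_fp32_to_fp16

    // CPU mirror of kernel_cpy_f32_f16: one outer iteration per source row,
    // destination coordinates recovered from the flattened index n.
    static void cpy_f32_f16_ref(
            const float * src0, void * dst,
            int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
            uint64_t nb00, uint64_t nb01, uint64_t nb02, uint64_t nb03,
            int64_t ne0, int64_t ne1, int64_t ne2,
            uint64_t nb0, uint64_t nb1, uint64_t nb2, uint64_t nb3) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                for (int64_t i01 = 0; i01 < ne01; i01++) {
                    // flattened index of the first element of this source row
                    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

                    // unflatten n against the destination shape
                    const int64_t i3 = n / (ne2*ne1*ne0);
                    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
                    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
                    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);

                    ggml_fp16_t * dst_row = (ggml_fp16_t *) ((char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

                    // on the GPU the threadgroup's 32 threads stride over this loop
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const float * src = (const float *) ((const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
                        dst_row[i00] = ggml_fp32_to_fp16(src[0]);
                    }
                }
            }
        }
    }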