diff --git a/examples/mtl/mtl.cpp b/examples/mtl/mtl.cpp
index 40e8fbcee..7f52453d8 100644
--- a/examples/mtl/mtl.cpp
+++ b/examples/mtl/mtl.cpp
@@ -41,13 +41,15 @@ int main(int argc, char ** argv) {
 
     // TODO: tmp to match the input used when creating the cgraph
     {
-        const int n_past = 128;
-        const int n_batch = 32;
+        const int n_batch = 1;
+        const int n_past = 512 - n_batch;
 
         const std::vector<int> tmp(n_batch, 1); // BOS
 
         // the actual inference happens here
-        llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past);
+        for (int i = 0; i < 10; ++i) {
+            llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past);
+        }
     }
 
     llama_mtl_free(ctx_mtl);
diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m
index 85003ebdd..ff1adf6df 100644
--- a/examples/mtl/mtl.m
+++ b/examples/mtl/mtl.m
@@ -429,14 +429,17 @@ int llama_mtl_eval(
                         const int64_t ne02 = gf->nodes[i]->src0->ne[2];
                         const int64_t ne03 = gf->nodes[i]->src0->ne[3];
 
+                        const int nth = 32;
+
                         [encoder setComputePipelineState:ctx->pipeline_soft_max];
                         [encoder setBuffer:id_src offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                         [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                         [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                         [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                        [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
                 case GGML_OP_DIAG_MASK_INF:
                     {
@@ -494,10 +497,10 @@ int llama_mtl_eval(
                         const enum ggml_type src1t = gf->nodes[i]->src1->type;
                         const enum ggml_type dstt  = gf->nodes[i]->type;
 
-                        printf("mul_mat: src0 - %s[%lld, %lld, %lld]\n", ggml_type_name(src0t), ne00, ne01, ne02);
-                        printf("mul_mat: src1 - %s[%lld, %lld, %lld]\n", ggml_type_name(src1t), ne10, ne11, ne12);
-                        printf("mul_mat: dst - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2);
-                        printf("mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt));
+                        fprintf(stderr, "mul_mat: src0 - %s[%lld, %lld, %lld]\n", ggml_type_name(src0t), ne00, ne01, ne02);
+                        fprintf(stderr, "mul_mat: src1 - %s[%lld, %lld, %lld]\n", ggml_type_name(src1t), ne10, ne11, ne12);
+                        fprintf(stderr, "mul_mat: dst - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2);
+                        fprintf(stderr, "mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt));
 
                         GGML_ASSERT(ne00 == ne10);
                         GGML_ASSERT(ne02 == ne12);
@@ -599,16 +602,19 @@ int llama_mtl_eval(
                         const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
 
                         const float eps = 1e-6f;
 
+                        const int nth = 32;
+
                         [encoder setComputePipelineState:ctx->pipeline_rms_norm];
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                         [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                         [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                         [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+                        [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
                         const int64_t nrows = ggml_nrows(gf->nodes[i]->src0);
 
-                        [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
                 case GGML_OP_ROPE:
                     {
@@ -643,9 +649,9 @@ int llama_mtl_eval(
                         const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
                         const int mode   = ((int32_t *) gf->nodes[i]->src1->data)[2];
 
-                        printf("rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
-                        printf("rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
-                        printf("rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
+                        fprintf(stderr, "rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
+                        fprintf(stderr, "rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
+                        fprintf(stderr, "rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
 
                         [encoder setComputePipelineState:ctx->pipeline_rope];
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -704,11 +710,13 @@ int llama_mtl_eval(
                         const enum ggml_type src0t = gf->nodes[i]->src0->type;
                         const enum ggml_type dstt  = gf->nodes[i]->type;
 
-                        printf("cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
-                        printf("cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
-                        printf("cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
-                        printf("cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3);
-                        printf("cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));
+                        const int nth = 32;
+
+                        fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
+                        fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
+                        fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
+                        fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3);
+                        fprintf(stderr, "cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));
 
                         switch (src0t) {
                             case GGML_TYPE_F32:
@@ -741,7 +749,7 @@ int llama_mtl_eval(
                         [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
                         [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
 
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
                 default:
                     fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
@@ -764,8 +772,6 @@ int llama_mtl_eval(
     id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, out, &offs_src0);
     id<MTLBuffer> id_dst = ctx->out;
 
-    printf("XXXXX n = %d\n", ggml_nelements(out));
-
     id<MTLBlitCommandEncoder> encoder_blit = [command_buffer blitCommandEncoder];
     [encoder_blit copyFromBuffer:id_src sourceOffset:offs_src0 toBuffer:id_dst destinationOffset:0 size:ggml_nbytes(out)];
     [encoder_blit endEncoding];
@@ -776,12 +782,29 @@ int llama_mtl_eval(
     {
         const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
 
-        fprintf(stderr, "%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
+        printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
     }
 
     // TODO
     const float * logits = ctx->out.contents;
 
+    printf("logits: ");
+    for (int i = 0; i < 100; i++) {
+        printf("%8.4f ", logits[i]);
+    }
+    printf("\n");
+    double sum = 0.0;
+    int imax = 0;
+    double vmax = -INFINITY;
+    for (int i = 0; i < 32000; i++) {
+        sum += (double) logits[i];
+        if (logits[i] > vmax) {
+            vmax = logits[i];
+            imax = i;
+        }
+    }
+    printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
+
     //{
     //    struct ggml_tensor * t = ggml_get_tensor(ctx->ctx_eval, "mtl-check");
     //    if (t->type == GGML_TYPE_F32) {
diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal
index 9ab51963f..f8446d17f 100644
--- a/examples/mtl/mtl.metal
+++ b/examples/mtl/mtl.metal
@@ -87,25 +87,80 @@ kernel void kernel_soft_max(
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
-       uint3 tpig[[thread_position_in_grid]]) {
-    const int64_t i03 = tpig[2];
-    const int64_t i02 = tpig[1];
-    const int64_t i01 = tpig[0];
+       threadgroup float * buf [[threadgroup(0)]],
+       uint3 tgpig[[threadgroup_position_in_grid]],
+       uint3 tpitg[[thread_position_in_threadgroup]],
+       uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
 
     device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
     device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
-    float max = 0.0f;
-    for (int i = 0; i < ne00; i++) {
-        max = MAX(max, psrc0[i]);
+    //float max = 0.0f;
+    //for (int i = 0; i < ne00; i++) {
+    //    max = MAX(max, psrc0[i]);
+    //}
+    //float sum = 0.0f;
+    //for (int i = 0; i < ne00; i++) {
+    //    pdst[i] = exp(psrc0[i] - max);
+    //    sum += pdst[i];
+    //}
+    //for (int i = 0; i < ne00; i++) {
+    //    pdst[i] /= sum;
+    //}
+
+    // parallel max
+    buf[tpitg[0]] = -INFINITY;
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
     }
-    float sum = 0.0f;
-    for (int i = 0; i < ne00; i++) {
-        pdst[i] = exp(psrc0[i] - max);
-        sum += pdst[i];
+
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg[0]/2; i > 0; i /= 2) {
+        if (tpitg[0] < i) {
+            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    for (int i = 0; i < ne00; i++) {
-        pdst[i] /= sum;
+
+    // broadcast
+    if (tpitg[0] == 0) {
+        buf[0] = buf[0];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const float max = buf[0];
+
+    // parallel sum
+    buf[tpitg[0]] = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        buf[tpitg[0]] += exp(psrc0[i00] - max);
+    }
+
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg[0]/2; i > 0; i /= 2) {
+        if (tpitg[0] < i) {
+            buf[tpitg[0]] += buf[tpitg[0] + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    // broadcast
+    if (tpitg[0] == 0) {
+        buf[0] = buf[0];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const float sum = buf[0];
+
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        pdst[i00] = exp(psrc0[i00] - max) / sum;
     }
 }
@@ -149,19 +204,39 @@ kernel void kernel_rms_norm(
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant     float & eps,
-       uint tpig[[thread_position_in_grid]]) {
-    device const float * x = (device const float *) ((device const char *) src0 + tpig*nb01);
+       threadgroup float * sum [[threadgroup(0)]],
+       uint tgpig[[threadgroup_position_in_grid]],
+       uint tpitg[[thread_position_in_threadgroup]],
+       uint ntg[[threads_per_threadgroup]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
 
-    float sum = 0.0f;
-    for (int i00 = 0; i00 < ne00; i00++) {
-        sum += x[i00] * x[i00];
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += x[i00] * x[i00];
     }
 
-    const float mean = sum/ne00;
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const float mean = sum[0];
     const float scale = 1.0f/sqrt(mean + eps);
 
-    device float * y = dst + tpig*ne00;
-    for (int i00 = 0; i00 < ne00; i00++) {
+    device float * y = dst + tgpig*ne00;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         y[i00] = x[i00] * scale;
     }
 }
diff --git a/ggml.c b/ggml.c
index 330a896ca..1c9bb4e61 100644
--- a/ggml.c
+++ b/ggml.c
@@ -14647,8 +14647,8 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
 }
 
 void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
-    assert(cgraph->work      == NULL);
-    assert(cgraph->work_size == 0);
+    //assert(cgraph->work      == NULL);
+    //assert(cgraph->work_size == 0);
 
     uint64_t size_eval = 0;
diff --git a/llama.cpp b/llama.cpp
index e0fbc6f73..c998a77fb 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1506,6 +1506,25 @@ static bool llama_eval_internal(
 
     if (cgraph_fname) {
         ggml_graph_export(&gf, cgraph_fname);
+
+        float * logits = (float *) ggml_get_data(inpL);
+
+        printf("logits: ");
+        for (int i = 0; i < 10; i++) {
+            printf("%8.4f ", logits[i]);
+        }
+        printf("\n");
+        double sum = 0.0;
+        int imax = 0;
+        double vmax = -INFINITY;
+        for (int i = 0; i < 32000; i++) {
+            sum += (double) logits[i];
+            if (logits[i] > vmax) {
+                vmax = logits[i];
+                imax = i;
+            }
+        }
+        printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
     }
 
 #ifdef GGML_PERF
@@ -3002,11 +3021,11 @@ int llama_eval(
 
 int llama_eval_export(struct llama_context * ctx, const char * fname) {
     // these values determine the maximum inference sizes of the exported computation graph
-    // TODO: TMP !!!
+    // TODO: need to increase buffers to support the full context
    //const int n_ctx   = ctx->model.hparams.n_ctx;
    //const int n_batch = 512;
-    const int n_ctx = 128;
-    const int n_batch = 32;
+    const int n_batch = 1;
+    const int n_ctx = 512 - n_batch;
 
     const std::vector<llama_token> tmp(n_batch, llama_token_bos());
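
For reference (not part of the patch): the rewritten kernel_soft_max replaces the per-row serial loops with a threadgroup reduction. Each of the nth threads strides over the row, writes its partial result into the nth-element threadgroup buffer, and a halving-stride tree reduction with barriers collapses the buffer to a single max and then a single sum. The plain-C sketch below simulates that pattern on the CPU under the assumption that nth is a small power of two (<= 64); the names nth, buf and ne00 mirror the kernel, the serial loops stand in for the threads of one threadgroup, and the barrier calls correspond to the boundaries between loops here.

// Standalone CPU sketch of the two-pass threadgroup softmax reduction (assumption: nth <= 64, power of two).
#include <math.h>
#include <stdio.h>

static void soft_max_rowwise(const float * src, float * dst, int ne00, int nth) {
    float buf[64]; // stands in for the threadgroup scratch buffer, one float per thread

    // parallel max: each "thread" t strides over the row
    for (int t = 0; t < nth; ++t) {
        buf[t] = -INFINITY;
        for (int i = t; i < ne00; i += nth) {
            buf[t] = fmaxf(buf[t], src[i]);
        }
    }
    // tree reduction, halving the number of active threads each step
    for (int s = nth/2; s > 0; s /= 2) {
        for (int t = 0; t < s; ++t) {
            buf[t] = fmaxf(buf[t], buf[t + s]);
        }
    }
    const float max = buf[0];

    // parallel sum of exp(x - max), then the same tree reduction
    for (int t = 0; t < nth; ++t) {
        buf[t] = 0.0f;
        for (int i = t; i < ne00; i += nth) {
            buf[t] += expf(src[i] - max);
        }
    }
    for (int s = nth/2; s > 0; s /= 2) {
        for (int t = 0; t < s; ++t) {
            buf[t] += buf[t + s];
        }
    }
    const float sum = buf[0];

    // normalize the row
    for (int i = 0; i < ne00; ++i) {
        dst[i] = expf(src[i] - max) / sum;
    }
}

int main(void) {
    const float x[8] = { 1.0f, 2.0f, 3.0f, 4.0f, 0.5f, -1.0f, 2.5f, 3.5f };
    float y[8];
    soft_max_rowwise(x, y, 8, 4);
    float check = 0.0f;
    for (int i = 0; i < 8; ++i) {
        printf("%8.4f ", y[i]);
        check += y[i];
    }
    printf("\nsum = %f (should be ~1)\n", check);
    return 0;
}

The same halving-stride reduction is what kernel_rms_norm now uses for its sum of squares, which is why both dispatches in mtl.m set a threadgroup memory length of nth*sizeof(float) and launch nth threads per threadgroup.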