diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m index f452979c4..89ed45c01 100644 --- a/examples/mtl/mtl.m +++ b/examples/mtl/mtl.m @@ -257,7 +257,7 @@ id llama_mtl_get_buffer(struct ggml_mtl_context * ctx, struct ggml_te const bool is_data = (offs_eval < 0) || (offs_data >= 0 && offs_data < offs_eval); - const size_t t_size = ggml_nbytes(t); + //const size_t t_size = ggml_nbytes(t); const size_t t_offs = is_data ? offs_data : offs_eval; id result; @@ -271,7 +271,7 @@ id llama_mtl_get_buffer(struct ggml_mtl_context * ctx, struct ggml_te } if (result == nil) { - //fprintf(stderr, "%s: error: buffer is nil\n", __func__); + fprintf(stderr, "%s: error: buffer is nil\n", __func__); GGML_ASSERT(false); } @@ -296,9 +296,9 @@ int llama_mtl_eval( id command_buffer = [ctx->queue commandBuffer]; id encoder = nil; - size_t offs_src0; - size_t offs_src1; - size_t offs_dst; + size_t offs_src0 = 0; + size_t offs_src1 = 0; + size_t offs_dst = 0; // copy the input data to the GPU { @@ -312,6 +312,48 @@ int llama_mtl_eval( for (int i = 0; i < gf->n_nodes; ++i) { //fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); + struct ggml_tensor * src0 = gf->nodes[i]->src0; + struct ggml_tensor * src1 = gf->nodes[i]->src1; + struct ggml_tensor * dst = gf->nodes[i]; + + const int64_t ne00 = src0 ? src0->ne[0] : 0; + const int64_t ne01 = src0 ? src0->ne[1] : 0; + const int64_t ne02 = src0 ? src0->ne[2] : 0; + const int64_t ne03 = src0 ? src0->ne[3] : 0; + + const uint64_t nb00 = src0 ? src0->nb[0] : 0; + const uint64_t nb01 = src0 ? src0->nb[1] : 0; + const uint64_t nb02 = src0 ? src0->nb[2] : 0; + const uint64_t nb03 = src0 ? src0->nb[3] : 0; + + const int64_t ne10 = src1 ? src1->ne[0] : 0; + const int64_t ne11 = src1 ? src1->ne[1] : 0; + const int64_t ne12 = src1 ? src1->ne[2] : 0; + //const int64_t ne13 = src1 ? src1->ne[3] : 0; + + const uint64_t nb10 = src1 ? src1->nb[0] : 0; + const uint64_t nb11 = src1 ? src1->nb[1] : 0; + const uint64_t nb12 = src1 ? src1->nb[2] : 0; + //const uint64_t nb13 = src1 ? src1->nb[3] : 0; + + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; + + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; + + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; + const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; + + id id_src0 = src0 ? llama_mtl_get_buffer(ctx, src0, &offs_src0) : nil; + id id_src1 = src1 ? llama_mtl_get_buffer(ctx, src1, &offs_src1) : nil; + id id_dst = dst ? llama_mtl_get_buffer(ctx, dst, &offs_dst) : nil; + switch (gf->nodes[i]->op) { case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -326,10 +368,6 @@ int llama_mtl_eval( encoder = [command_buffer computeCommandEncoder]; } - id id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0); - id id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1); - id id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst); - [encoder setComputePipelineState:ctx->pipeline_add]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -345,14 +383,6 @@ int llama_mtl_eval( encoder = [command_buffer computeCommandEncoder]; } - id id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0); - id id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1); - id id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst); - - const int64_t ne00 = gf->nodes[i]->src0->ne[0]; - - const int64_t ne10 = gf->nodes[i]->src1->ne[0]; - if (ggml_nelements(gf->nodes[i]->src1) == ne10) { // src1 is a row [encoder setComputePipelineState:ctx->pipeline_mul_row]; @@ -374,9 +404,6 @@ int llama_mtl_eval( encoder = [command_buffer computeCommandEncoder]; } - id id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0); - id id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst); - const float scale = *(const float *) gf->nodes[i]->src1->data; [encoder setComputePipelineState:ctx->pipeline_scale]; @@ -394,9 +421,6 @@ int llama_mtl_eval( encoder = [command_buffer computeCommandEncoder]; } - id id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0); - id id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst); - [encoder setComputePipelineState:ctx->pipeline_silu]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -411,12 +435,9 @@ int llama_mtl_eval( encoder = [command_buffer computeCommandEncoder]; } - id id_src = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0); - id id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst); - [encoder setComputePipelineState:ctx->pipeline_relu]; - [encoder setBuffer:id_src offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; const int64_t n = ggml_nelements(gf->nodes[i]); @@ -428,19 +449,11 @@ int llama_mtl_eval( encoder = [command_buffer computeCommandEncoder]; } - id id_src = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0); - id id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst); - - const int64_t ne00 = gf->nodes[i]->src0->ne[0]; - const int64_t ne01 = gf->nodes[i]->src0->ne[1]; - const int64_t ne02 = gf->nodes[i]->src0->ne[2]; - const int64_t ne03 = gf->nodes[i]->src0->ne[3]; - const int nth = 32; [encoder setComputePipelineState:ctx->pipeline_soft_max]; - [encoder setBuffer:id_src offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; @@ -454,16 +467,9 @@ int llama_mtl_eval( encoder = [command_buffer computeCommandEncoder]; } - id id_src = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0); - id id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst); - - const int64_t ne00 = gf->nodes[i]->src0->ne[0]; - const int64_t ne01 = gf->nodes[i]->src0->ne[1]; - const int64_t ne02 = gf->nodes[i]->src0->ne[2]; - [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; - [encoder setBuffer:id_src offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; @@ -472,38 +478,6 @@ int llama_mtl_eval( } break; case GGML_OP_MUL_MAT: { - id id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0); - id id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1); - id id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst); - - const int64_t ne00 = gf->nodes[i]->src0->ne[0]; - const int64_t ne01 = gf->nodes[i]->src0->ne[1]; - const int64_t ne02 = gf->nodes[i]->src0->ne[2]; - - const uint64_t nb00 = gf->nodes[i]->src0->nb[0]; - const uint64_t nb01 = gf->nodes[i]->src0->nb[1]; - const uint64_t nb02 = gf->nodes[i]->src0->nb[2]; - - const int64_t ne10 = gf->nodes[i]->src1->ne[0]; - const int64_t ne11 = gf->nodes[i]->src1->ne[1]; - const int64_t ne12 = gf->nodes[i]->src1->ne[2]; - - const uint64_t nb10 = gf->nodes[i]->src1->nb[0]; - const uint64_t nb11 = gf->nodes[i]->src1->nb[1]; - const uint64_t nb12 = gf->nodes[i]->src1->nb[2]; - - const int64_t ne0 = gf->nodes[i]->ne[0]; - const int64_t ne1 = gf->nodes[i]->ne[1]; - const int64_t ne2 = gf->nodes[i]->ne[2]; - - const uint64_t nb0 = gf->nodes[i]->nb[0]; - const uint64_t nb1 = gf->nodes[i]->nb[1]; - const uint64_t nb2 = gf->nodes[i]->nb[2]; - - const enum ggml_type src0t = gf->nodes[i]->src0->type; - const enum ggml_type src1t = gf->nodes[i]->src1->type; - const enum ggml_type dstt = gf->nodes[i]->type; - //fprintf(stderr, "mul_mat: src0 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src0t), ne00, ne01, ne02, ggml_is_contiguous(gf->nodes[i]->src0)); //fprintf(stderr, "mul_mat: src1 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src1t), ne10, ne11, ne12, ggml_is_contiguous(gf->nodes[i]->src1)); //fprintf(stderr, "mul_mat: dst - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2); @@ -613,10 +587,6 @@ int llama_mtl_eval( encoder = [command_buffer computeCommandEncoder]; } - id id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0); - id id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1); - id id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst); - switch (gf->nodes[i]->src0->type) { case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; default: { @@ -642,12 +612,7 @@ int llama_mtl_eval( encoder = [command_buffer computeCommandEncoder]; } - id id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0); - id id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst); - - const int64_t ne00 = gf->nodes[i]->src0->ne[0]; - const uint64_t nb01 = gf->nodes[i]->src0->nb[1]; - const float eps = 1e-6f; + const float eps = 1e-6f; const int nth = 256; @@ -669,30 +634,6 @@ int llama_mtl_eval( encoder = [command_buffer computeCommandEncoder]; } - id id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0); - id id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst); - - const int64_t ne00 = gf->nodes[i]->src0->ne[0]; - const int64_t ne01 = gf->nodes[i]->src0->ne[1]; - const int64_t ne02 = gf->nodes[i]->src0->ne[2]; - const int64_t ne03 = gf->nodes[i]->src0->ne[3]; - - const uint64_t nb00 = gf->nodes[i]->src0->nb[0]; - const uint64_t nb01 = gf->nodes[i]->src0->nb[1]; - const uint64_t nb02 = gf->nodes[i]->src0->nb[2]; - const uint64_t nb03 = gf->nodes[i]->src0->nb[3]; - - const int64_t ne0 = gf->nodes[i]->ne[0]; - const int64_t ne1 = gf->nodes[i]->ne[1]; - const int64_t ne2 = gf->nodes[i]->ne[2]; - const int64_t ne3 = gf->nodes[i]->ne[3]; - - const uint64_t nb0 = gf->nodes[i]->nb[0]; - const uint64_t nb1 = gf->nodes[i]->nb[1]; - const uint64_t nb2 = gf->nodes[i]->nb[2]; - const uint64_t nb3 = gf->nodes[i]->nb[3]; - - //const int n_past = ((int32_t *) gf->nodes[i]->src1->data)[0]; // TODO: TMP !!!!! const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1]; const int mode = ((int32_t *) gf->nodes[i]->src1->data)[2]; @@ -731,32 +672,6 @@ int llama_mtl_eval( encoder = [command_buffer computeCommandEncoder]; } - id id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0); - id id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst); - - const int64_t ne00 = gf->nodes[i]->src0->ne[0]; - const int64_t ne01 = gf->nodes[i]->src0->ne[1]; - const int64_t ne02 = gf->nodes[i]->src0->ne[2]; - const int64_t ne03 = gf->nodes[i]->src0->ne[3]; - - const uint64_t nb00 = gf->nodes[i]->src0->nb[0]; - const uint64_t nb01 = gf->nodes[i]->src0->nb[1]; - const uint64_t nb02 = gf->nodes[i]->src0->nb[2]; - const uint64_t nb03 = gf->nodes[i]->src0->nb[3]; - - const int64_t ne0 = gf->nodes[i]->ne[0]; - const int64_t ne1 = gf->nodes[i]->ne[1]; - const int64_t ne2 = gf->nodes[i]->ne[2]; - const int64_t ne3 = gf->nodes[i]->ne[3]; - - const uint64_t nb0 = gf->nodes[i]->nb[0]; - const uint64_t nb1 = gf->nodes[i]->nb[1]; - const uint64_t nb2 = gf->nodes[i]->nb[2]; - const uint64_t nb3 = gf->nodes[i]->nb[3]; - - const enum ggml_type src0t = gf->nodes[i]->src0->type; - const enum ggml_type dstt = gf->nodes[i]->type; - const int nth = 32; //fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03); @@ -835,6 +750,7 @@ int llama_mtl_eval( // TODO const float * logits = ctx->out.contents; +#if 1 printf("logits: "); for (int i = 0; i < 100; i++) { printf("%8.4f ", logits[i]); @@ -851,6 +767,7 @@ int llama_mtl_eval( } } printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax); +#endif //{ // struct ggml_tensor * t = ggml_get_tensor(ctx->ctx_eval, "mtl-check");