diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m
index 06d8961ee..bb0074a4c 100644
--- a/examples/mtl/mtl.m
+++ b/examples/mtl/mtl.m
@@ -378,11 +378,19 @@ int llama_mtl_eval(
                     id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
                     id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
 
+                    const int64_t ne00 = gf->nodes[i]->src0->ne[0];
+                    const int64_t ne01 = gf->nodes[i]->src0->ne[1];
+                    const int64_t ne02 = gf->nodes[i]->src0->ne[2];
+                    const int64_t ne03 = gf->nodes[i]->src0->ne[3];
+
                     [encoder setComputePipelineState:ctx->pipeline_soft_max];
                     [encoder setBuffer:id_src offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_dst offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
 
-                    [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
             case GGML_OP_DIAG_MASK_INF:
                 {
diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal
index ef2b690c1..32e850297 100644
--- a/examples/mtl/mtl.metal
+++ b/examples/mtl/mtl.metal
@@ -31,9 +31,6 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i
     }
 }
 
-// TODO: not needed
-constant int nsoftmax [[function_constant(0)]];
-
 kernel void kernel_add(
         device const float * src0,
         device const float * src1,
@@ -68,42 +65,34 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }
 
-// TODO: broken
 kernel void kernel_soft_max(
         device const float * src0,
-        device       float * dst) {
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        uint3 tpig[[thread_position_in_grid]]) {
+    const int64_t i03 = tpig[2];
+    const int64_t i02 = tpig[1];
+    const int64_t i01 = tpig[0];
+
+    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
     float max = 0.0f;
-    for (int i = 0; i < nsoftmax; i++) {
-        max = MAX(max, src0[i]);
+    for (int i = 0; i < ne00; i++) {
+        max = MAX(max, psrc0[i]);
     }
 
     float sum = 0.0f;
-    for (int i = 0; i < nsoftmax; i++) {
-        dst[i] = exp(src0[i] - max);
-        sum += dst[i];
+    for (int i = 0; i < ne00; i++) {
+        pdst[i] = exp(psrc0[i] - max);
+        sum += pdst[i];
     }
 
-    for (int i = 0; i < nsoftmax; i++) {
-        dst[i] /= sum;
+    for (int i = 0; i < ne00; i++) {
+        pdst[i] /= sum;
     }
 }
 
-//const int n  = ggml_nrows(src0);
-//const int nc = src0->ne[0];
-//const int nr = src0->ne[1];
-//const int nz = n/nr;
-//
-//assert( dst->nb[0] == sizeof(float));
-//assert(src0->nb[0] == sizeof(float));
-//
-//for (int k = 0; k < nz; k++) {
-//    for (int j = ith; j < nr; j += nth) {
-//        for (int i = n_past; i < nc; i++) {
-//            if (i > n_past + j) {
-//                *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
-//            }
-//        }
-//    }
-//}
-
 kernel void kernel_diag_mask_inf(
         device const float * src0,
         device       float * dst,
diff --git a/llama.cpp b/llama.cpp
index ff4268ed6..6825636c8 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1336,15 +1336,15 @@ static bool llama_eval_internal(
             // KQ_masked = mask_past(KQ_scaled)
             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
             ggml_set_name(KQ_masked, "KQ_masked");
-            // TODO: TMP !!!!
-            if (il == 0) {
-                ggml_set_name(KQ_masked, "mtl-check");
-            }
 
             // KQ = soft_max(KQ_masked)
             struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
             ggml_set_name(KQ_soft_max, "KQ_soft_max");
+            // TODO: TMP !!!!
+            if (il == 0) {
+                ggml_set_name(KQ_soft_max, "mtl-check");
+            }
 
             // split cached V into n_head heads
             struct ggml_tensor * V =