diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m
index b3f21f347..ade0719ba 100644
--- a/examples/mtl/mtl.m
+++ b/examples/mtl/mtl.m
@@ -32,6 +32,9 @@ struct ggml_mtl_context {
 
     id<MTLFunction>             function_get_rows_q4_0;
     id<MTLComputePipelineState> pipeline_get_rows_q4_0;
+
+    id<MTLFunction>             function_rms_norm;
+    id<MTLComputePipelineState> pipeline_rms_norm;
 };
 
 // MSL code
@@ -127,6 +130,10 @@ struct ggml_mtl_context * llama_mtl_init(
         ctx->function_get_rows_q4_0 = [ctx->library newFunctionWithName:@"kernel_get_rows_q4_0"];
         ctx->pipeline_get_rows_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_get_rows_q4_0 error:nil];
         fprintf(stderr, "%s: loaded kernel_get_rows_q4_0: %p\n", __func__, (void *) ctx->pipeline_get_rows_q4_0);
+
+        ctx->function_rms_norm = [ctx->library newFunctionWithName:@"kernel_rms_norm"];
+        ctx->pipeline_rms_norm = [ctx->device newComputePipelineStateWithFunction:ctx->function_rms_norm error:nil];
+        fprintf(stderr, "%s: loaded kernel_rms_norm: %p\n", __func__, (void *) ctx->pipeline_rms_norm);
     }
 
     // MTLBuffer approach
@@ -348,6 +355,30 @@ int llama_mtl_eval(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case GGML_OP_RMS_NORM:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    const int64_t  ne00 = gf->nodes[i]->src0->ne[0];
+                    const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
+                    const float    eps  = 1e-6f;
+
+                    [encoder setComputePipelineState:ctx->pipeline_rms_norm];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+                    [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+
+                    const int64_t nrows = ggml_nrows(gf->nodes[i]->src0);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
             default:
                 fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
                 GGML_ASSERT(false);
diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal
index 01ffec018..6a736446b 100644
--- a/examples/mtl/mtl.metal
+++ b/examples/mtl/mtl.metal
@@ -43,23 +43,23 @@ kernel void kernel_add(
 }
 
 kernel void kernel_relu(
-        device const float * src,
+        device const float * src0,
         device       float * dst,
         uint gid[[thread_position_in_grid]]) {
-    dst[gid] = max(0.0f, src[gid]);
+    dst[gid] = max(0.0f, src0[gid]);
 }
 
 // TODO: broken
 kernel void kernel_soft_max(
-        device const float * src,
+        device const float * src0,
         device       float * dst) {
     float max = 0.0f;
     for (int i = 0; i < nsoftmax; i++) {
-        max = MAX(max, src[i]);
+        max = MAX(max, src0[i]);
     }
     float sum = 0.0f;
     for (int i = 0; i < nsoftmax; i++) {
-        dst[i] = exp(src[i] - max);
+        dst[i] = exp(src0[i] - max);
         sum += dst[i];
     }
     for (int i = 0; i < nsoftmax; i++) {
@@ -75,8 +75,6 @@ kernel void kernel_get_rows_q4_0(
         constant uint64_t & nb01,
         constant uint64_t & nb1,
         uint gid[[thread_position_in_grid]]) {
-    device const block_q4_0 * src = (device const block_q4_0 *)src0;
-
     const int i = gid;
     const int r = ((device int32_t *) src1)[i];
 
@@ -84,3 +82,26 @@
             (device const block_q4_0 *) ((device char *) src0 + r*nb01),
             (device float *) ((device char *) dst + i*nb1), ne00);
 }
+
+kernel void kernel_rms_norm(
+        device const void  * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant     float & eps,
+        uint gid[[thread_position_in_grid]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + gid*nb01);
+
+    float sum = 0.0f;
+    for (int i00 = 0; i00 < ne00; i00++) {
+        sum += x[i00] * x[i00];
+    }
+
+    const float mean  = sum/ne00;
+    const float scale = 1.0f/sqrt(mean + eps);
+
+    device float * y = dst + gid*ne00;
+    for (int i00 = 0; i00 < ne00; i00++) {
+        y[i00] = x[i00] * scale;
+    }
+}
diff --git a/ggml.c b/ggml.c
index 4cd0d7211..823d904ee 100644
--- a/ggml.c
+++ b/ggml.c
@@ -3723,7 +3723,7 @@ int64_t ggml_nelements(const struct ggml_tensor * tensor) {
     return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
 }
 
-int ggml_nrows(const struct ggml_tensor * tensor) {
+int64_t ggml_nrows(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
@@ -9245,7 +9245,7 @@ static void ggml_compute_forward_rms_norm_f32(
                     sum += (ggml_float)(x[i00] * x[i00]);
                 }
 
-                float mean = sum/ne00;
+                const float mean = sum/ne00;
 
                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
diff --git a/ggml.h b/ggml.h
index 60c0ad8bf..1f033b492 100644
--- a/ggml.h
+++ b/ggml.h
@@ -425,6 +425,7 @@ extern "C" {
    GGML_API void    ggml_print_objects(const struct ggml_context * ctx);
 
     GGML_API int64_t ggml_nelements(const struct ggml_tensor * tensor);
+    GGML_API int64_t ggml_nrows    (const struct ggml_tensor * tensor);
     GGML_API size_t  ggml_nbytes   (const struct ggml_tensor * tensor);
 
     GGML_API int     ggml_blck_size (enum ggml_type type);
diff --git a/llama.cpp b/llama.cpp
index c5ea19ac9..3ee170e4c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1252,7 +1252,6 @@ static bool llama_eval_internal(
     memcpy(embd->data, tokens, N*ggml_element_size(embd));
 
     struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
-    ggml_set_name(inpL, "mtl-check");
 
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * inpSA = inpL;
@@ -1264,16 +1263,15 @@
         // norm
         {
             cur = ggml_rms_norm(ctx0, inpL);
+            // TODO: TMP !!!!
+            if (il == 0) {
+                ggml_set_name(cur, "mtl-check");
+            }
 
             // cur = cur*attention_norm(broadcasted)
             cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
         }
 
-        // TODO: TMP !!!!
-        //if (il == 0) {
-        //    ggml_set_name(cur, "mtl-check");
-        //}
-
         // self-attention
         {
             // compute Q and K and RoPE them
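
Note: kernel_rms_norm computes the same row-wise normalization as the CPU path in ggml_compute_forward_rms_norm_f32: each row of ne00 floats is scaled by 1/sqrt(mean(x*x) + eps), with eps hardcoded to 1e-6f on the Metal side. Below is a minimal CPU reference sketch in C for checking the Metal output against (e.g. via the "mtl-check" tensor this patch names in llama.cpp); the function rms_norm_ref is illustrative and not part of the patch.

    #include <math.h>
    #include <stdint.h>

    // Row-wise RMS norm mirroring kernel_rms_norm: `nrows` rows of `ne00`
    // floats, consecutive rows `nb01` bytes apart in src0, contiguous in dst.
    static void rms_norm_ref(const void * src0, float * dst,
                             int64_t ne00, uint64_t nb01, int64_t nrows, float eps) {
        for (int64_t i01 = 0; i01 < nrows; i01++) {
            const float * x = (const float *) ((const char *) src0 + i01*nb01);
            float       * y = dst + i01*ne00;

            // sum of squares over the row
            float sum = 0.0f;
            for (int64_t i00 = 0; i00 < ne00; i00++) {
                sum += x[i00]*x[i00];
            }

            // scale each element by 1/sqrt(mean + eps)
            const float scale = 1.0f/sqrtf(sum/ne00 + eps);
            for (int64_t i00 = 0; i00 < ne00; i00++) {
                y[i00] = x[i00]*scale;
            }
        }
    }

The dispatch in mtl.m launches nrows threadgroups of a single thread each, so every row is reduced serially by one thread, matching the loop structure above; a per-row parallel reduction inside a threadgroup would be a possible follow-up optimization.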