From f9c8ccc12fd05f4409ad112881fdb5b552ece170 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 7 Sep 2023 15:17:43 +0200 Subject: [PATCH] Fix kernel_norm broken by ca82cf7 --- ggml-metal.metal | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 119fcbeb6..4f321d96f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -220,14 +220,10 @@ kernel void kernel_norm( } threadgroup_barrier(mem_flags::mem_threadgroup); } - //// broadcast - //if (tpitg == 0) { - // sum[0] /= ne00; - //} - //threadgroup_barrier(mem_flags::mem_threadgroup); - const float mean = sum[0]; + const float mean = sum[0] / ne00; // recenter and VARIANCE + threadgroup_barrier(mem_flags::mem_threadgroup); device float * y = dst + tgpig*ne00; sum[tpitg] = 0.0f; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { @@ -235,12 +231,6 @@ kernel void kernel_norm( sum[tpitg] += y[i00] * y[i00]; } - //// VARIANCE - //// parallel sum - //sum[tpitg] = 0.0f; - //for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - // sum[tpitg] += y[i00] * y[i00]; - //} // reduce threadgroup_barrier(mem_flags::mem_threadgroup); for (uint i = ntg/2; i > 0; i /= 2) { @@ -249,12 +239,7 @@ kernel void kernel_norm( } threadgroup_barrier(mem_flags::mem_threadgroup); } - //// broadcast - //if (tpitg == 0) { - // sum[0] /= ne00; - //} - //threadgroup_barrier(mem_flags::mem_threadgroup); - const float variance = sum[0]; + const float variance = sum[0] / ne00; const float scale = 1.0f/sqrt(variance + eps); for (int i00 = tpitg; i00 < ne00; i00 += ntg) {