From c675aaf0b5fee8c2e6c5725762154009c9d48be0 Mon Sep 17 00:00:00 2001 From: luoyu-intel Date: Tue, 2 Jul 2024 06:55:06 +0000 Subject: [PATCH] fix group_norm unit test --- ggml/src/ggml-sycl/norm.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index ed0fa7e31..e0c5dfeca 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -57,6 +57,7 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con const int nwarps = nthreads / WARP_SIZE; assert(nwarps % WARP_SIZE == 0); start += item_ct1.get_local_id(2); + int nreduce = nwarps / WARP_SIZE; if (end >= ne_elements) { end = ne_elements; @@ -87,7 +88,6 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con */ item_ct1.barrier(); tmp = 0.f; - int nreduce = nwarps / WARP_SIZE; for (size_t i = 0; i < nreduce; i += 1) { tmp += s_sum[lane_id + i * WARP_SIZE]; @@ -122,7 +122,11 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con better performance if there is no access to global memory. */ item_ct1.barrier(); - tmp = s_sum[lane_id]; + tmp = 0.f; + for (size_t i = 0; i < nreduce; i += 1) + { + tmp += s_sum[lane_id + i * WARP_SIZE]; + } tmp = warp_reduce_sum(tmp, item_ct1); }