From 9a48a4536a457e252c5800e9968ad1be54a6a19f Mon Sep 17 00:00:00 2001 From: luoyu-intel Date: Tue, 25 Jun 2024 16:34:01 +0800 Subject: [PATCH] fix rms_norm with correct warp number --- ggml/src/ggml-sycl.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index 28698b610..75bd5c5f7 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl.cpp @@ -342,6 +342,9 @@ static void norm_f32(const float * x, float * dst, const int ncols, const float item_ct1.get_local_id(1); const int tid = item_ct1.get_local_id(2); + const int nthreads = item_ct1.get_group_range(2); + const int nwarps = nthreads / WARP_SIZE; + assert(nwarps % WARP_SIZE == 0); sycl::float2 mean_var = sycl::float2(0.f, 0.f); for (int col = tid; col < ncols; col += block_size) { @@ -364,7 +367,12 @@ static void norm_f32(const float * x, float * dst, const int ncols, const float converged control flow. You may need to adjust the code. */ item_ct1.barrier(sycl::access::fence_space::local_space); - mean_var = s_sum[lane_id]; + mean_var = 0.f; + int nreduce = nwarps / WARP_SIZE; + for (size_t i = 0; i < nreduce; i+= 1) + { + mean_var += s_sum[lane_id + i * WARP_SIZE]; + } mean_var = warp_reduce_sum(mean_var, item_ct1); }