diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index 75bd5c5f7..b9f14e1ea 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl.cpp @@ -534,7 +534,9 @@ static void rms_norm_f32(const float * x, float * dst, const int ncols, const fl const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); const int tid = item_ct1.get_local_id(2); - + const int nthreads = item_ct1.get_group_range(2); + const int nwarps = nthreads / WARP_SIZE; + assert(nwarps % WARP_SIZE == 0); float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { @@ -556,7 +558,12 @@ static void rms_norm_f32(const float * x, float * dst, const int ncols, const fl converged control flow. You may need to adjust the code. */ item_ct1.barrier(sycl::access::fence_space::local_space); - tmp = s_sum[lane_id]; + int nreduce = nwarps / WARP_SIZE; + tmp = 0.f; + for (size_t i = 0; i < nreduce; i += 1) + { + tmp += s_sum[lane_id + i * WARP_SIZE]; + } tmp = warp_reduce_sum(tmp, item_ct1); }