diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index b9f14e1ea..5a716d0ee 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl.cpp @@ -342,7 +342,7 @@ static void norm_f32(const float * x, float * dst, const int ncols, const float item_ct1.get_local_id(1); const int tid = item_ct1.get_local_id(2); - const int nthreads = item_ct1.get_group_range(2); + const int nthreads = item_ct1.get_local_range(2); const int nwarps = nthreads / WARP_SIZE; assert(nwarps % WARP_SIZE == 0); sycl::float2 mean_var = sycl::float2(0.f, 0.f); @@ -456,7 +456,9 @@ static void group_norm_f32(const float * x, float * dst, const int group_size, c const sycl::nd_item<3> &item_ct1, float *s_sum, int block_size) { int start = item_ct1.get_group(2) * group_size; int end = start + group_size; - + const int nthreads = item_ct1.get_local_range(2); + const int nwarps = nthreads / WARP_SIZE; + assert(nwarps % WARP_SIZE == 0); start += item_ct1.get_local_id(2); if (end >= ne_elements) { @@ -487,7 +489,12 @@ static void group_norm_f32(const float * x, float * dst, const int group_size, c better performance if there is no access to global memory. */ item_ct1.barrier(); - tmp = s_sum[lane_id]; + tmp = 0.f; + int nreduce = nwarps / WARP_SIZE; + for (size_t i = 0; i < nreduce; i += 1) + { + tmp += s_sum[lane_id + i * WARP_SIZE]; + } tmp = warp_reduce_sum(tmp, item_ct1); } @@ -534,7 +541,7 @@ static void rms_norm_f32(const float * x, float * dst, const int ncols, const fl const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); const int tid = item_ct1.get_local_id(2); - const int nthreads = item_ct1.get_group_range(2); + const int nthreads = item_ct1.get_local_range(2); const int nwarps = nthreads / WARP_SIZE; assert(nwarps % WARP_SIZE == 0); float tmp = 0.0f; // partial sum for thread in warp