fix group_norm_f32

This commit is contained in:
luoyu-intel 2024-06-25 16:56:16 +08:00
parent 33df09d95b
commit 90e0328038

View file

@ -342,7 +342,7 @@ static void norm_f32(const float * x, float * dst, const int ncols, const float
item_ct1.get_local_id(1);
const int tid = item_ct1.get_local_id(2);
const int nthreads = item_ct1.get_group_range(2);
const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE;
assert(nwarps % WARP_SIZE == 0);
sycl::float2 mean_var = sycl::float2(0.f, 0.f);
@ -456,7 +456,9 @@ static void group_norm_f32(const float * x, float * dst, const int group_size, c
const sycl::nd_item<3> &item_ct1, float *s_sum, int block_size) {
int start = item_ct1.get_group(2) * group_size;
int end = start + group_size;
const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE;
assert(nwarps % WARP_SIZE == 0);
start += item_ct1.get_local_id(2);
if (end >= ne_elements) {
@ -487,7 +489,12 @@ static void group_norm_f32(const float * x, float * dst, const int group_size, c
better performance if there is no access to global memory.
*/
item_ct1.barrier();
tmp = s_sum[lane_id];
tmp = 0.f;
int nreduce = nwarps / WARP_SIZE;
for (size_t i = 0; i < nreduce; i += 1)
{
tmp += s_sum[lane_id + i * WARP_SIZE];
}
tmp = warp_reduce_sum(tmp, item_ct1);
}
@ -534,7 +541,7 @@ static void rms_norm_f32(const float * x, float * dst, const int ncols, const fl
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);
const int tid = item_ct1.get_local_id(2);
const int nthreads = item_ct1.get_group_range(2);
const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE;
assert(nwarps % WARP_SIZE == 0);
float tmp = 0.0f; // partial sum for thread in warp