Fix rms_norm to use the correct number of warps

This commit is contained in:
luoyu-intel 2024-06-25 16:34:01 +08:00
parent a85ad06eb4
commit 9a48a4536a

View file

@@ -342,6 +342,9 @@ static void norm_f32(const float * x, float * dst, const int ncols, const float
item_ct1.get_local_id(1); item_ct1.get_local_id(1);
const int tid = item_ct1.get_local_id(2); const int tid = item_ct1.get_local_id(2);
const int nthreads = item_ct1.get_group_range(2);
const int nwarps = nthreads / WARP_SIZE;
assert(nwarps % WARP_SIZE == 0);
sycl::float2 mean_var = sycl::float2(0.f, 0.f); sycl::float2 mean_var = sycl::float2(0.f, 0.f);
for (int col = tid; col < ncols; col += block_size) { for (int col = tid; col < ncols; col += block_size) {
@@ -364,7 +367,12 @@ static void norm_f32(const float * x, float * dst, const int ncols, const float
converged control flow. You may need to adjust the code. converged control flow. You may need to adjust the code.
*/ */
item_ct1.barrier(sycl::access::fence_space::local_space); item_ct1.barrier(sycl::access::fence_space::local_space);
mean_var = s_sum[lane_id]; mean_var = 0.f;
int nreduce = nwarps / WARP_SIZE;
for (size_t i = 0; i < nreduce; i+= 1)
{
mean_var += s_sum[lane_id + i * WARP_SIZE];
}
mean_var = warp_reduce_sum(mean_var, item_ct1); mean_var = warp_reduce_sum(mean_var, item_ct1);
} }