fix rms_norm_f32
This commit is contained in:
parent
9a48a4536a
commit
33df09d95b
1 changed files with 9 additions and 2 deletions
|
@ -534,7 +534,9 @@ static void rms_norm_f32(const float * x, float * dst, const int ncols, const fl
|
||||||
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
|
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
|
||||||
item_ct1.get_local_id(1);
|
item_ct1.get_local_id(1);
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int tid = item_ct1.get_local_id(2);
|
||||||
|
const int nthreads = item_ct1.get_group_range(2);
|
||||||
|
const int nwarps = nthreads / WARP_SIZE;
|
||||||
|
assert(nwarps % WARP_SIZE == 0);
|
||||||
float tmp = 0.0f; // partial sum for thread in warp
|
float tmp = 0.0f; // partial sum for thread in warp
|
||||||
|
|
||||||
for (int col = tid; col < ncols; col += block_size) {
|
for (int col = tid; col < ncols; col += block_size) {
|
||||||
|
@ -556,7 +558,12 @@ static void rms_norm_f32(const float * x, float * dst, const int ncols, const fl
|
||||||
converged control flow. You may need to adjust the code.
|
converged control flow. You may need to adjust the code.
|
||||||
*/
|
*/
|
||||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||||
tmp = s_sum[lane_id];
|
int nreduce = nwarps / WARP_SIZE;
|
||||||
|
tmp = 0.f;
|
||||||
|
for (size_t i = 0; i < nreduce; i += 1)
|
||||||
|
{
|
||||||
|
tmp += s_sum[lane_id + i * WARP_SIZE];
|
||||||
|
}
|
||||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue