fix group_norm_f32
This commit is contained in:
parent
33df09d95b
commit
90e0328038
1 changed files with 11 additions and 4 deletions
|
@ -342,7 +342,7 @@ static void norm_f32(const float * x, float * dst, const int ncols, const float
|
|||
item_ct1.get_local_id(1);
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
|
||||
const int nthreads = item_ct1.get_group_range(2);
|
||||
const int nthreads = item_ct1.get_local_range(2);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
assert(nwarps % WARP_SIZE == 0);
|
||||
sycl::float2 mean_var = sycl::float2(0.f, 0.f);
|
||||
|
@ -456,7 +456,9 @@ static void group_norm_f32(const float * x, float * dst, const int group_size, c
|
|||
const sycl::nd_item<3> &item_ct1, float *s_sum, int block_size) {
|
||||
int start = item_ct1.get_group(2) * group_size;
|
||||
int end = start + group_size;
|
||||
|
||||
const int nthreads = item_ct1.get_local_range(2);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
assert(nwarps % WARP_SIZE == 0);
|
||||
start += item_ct1.get_local_id(2);
|
||||
|
||||
if (end >= ne_elements) {
|
||||
|
@ -487,7 +489,12 @@ static void group_norm_f32(const float * x, float * dst, const int group_size, c
|
|||
better performance if there is no access to global memory.
|
||||
*/
|
||||
item_ct1.barrier();
|
||||
tmp = s_sum[lane_id];
|
||||
tmp = 0.f;
|
||||
int nreduce = nwarps / WARP_SIZE;
|
||||
for (size_t i = 0; i < nreduce; i += 1)
|
||||
{
|
||||
tmp += s_sum[lane_id + i * WARP_SIZE];
|
||||
}
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
}
|
||||
|
||||
|
@ -534,7 +541,7 @@ static void rms_norm_f32(const float * x, float * dst, const int ncols, const fl
|
|||
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
|
||||
item_ct1.get_local_id(1);
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int nthreads = item_ct1.get_group_range(2);
|
||||
const int nthreads = item_ct1.get_local_range(2);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
assert(nwarps % WARP_SIZE == 0);
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue