fix group_norm ut
This commit is contained in:
parent
f09b7cb609
commit
c675aaf0b5
1 changed file with 6 additions and 2 deletions
|
@ -57,6 +57,7 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
|
||||||
const int nwarps = nthreads / WARP_SIZE;
|
const int nwarps = nthreads / WARP_SIZE;
|
||||||
assert(nwarps % WARP_SIZE == 0);
|
assert(nwarps % WARP_SIZE == 0);
|
||||||
start += item_ct1.get_local_id(2);
|
start += item_ct1.get_local_id(2);
|
||||||
|
int nreduce = nwarps / WARP_SIZE;
|
||||||
|
|
||||||
if (end >= ne_elements) {
|
if (end >= ne_elements) {
|
||||||
end = ne_elements;
|
end = ne_elements;
|
||||||
|
@ -87,7 +88,6 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
|
||||||
*/
|
*/
|
||||||
item_ct1.barrier();
|
item_ct1.barrier();
|
||||||
tmp = 0.f;
|
tmp = 0.f;
|
||||||
int nreduce = nwarps / WARP_SIZE;
|
|
||||||
for (size_t i = 0; i < nreduce; i += 1)
|
for (size_t i = 0; i < nreduce; i += 1)
|
||||||
{
|
{
|
||||||
tmp += s_sum[lane_id + i * WARP_SIZE];
|
tmp += s_sum[lane_id + i * WARP_SIZE];
|
||||||
|
@ -122,7 +122,11 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
|
||||||
better performance if there is no access to global memory.
|
better performance if there is no access to global memory.
|
||||||
*/
|
*/
|
||||||
item_ct1.barrier();
|
item_ct1.barrier();
|
||||||
tmp = s_sum[lane_id];
|
tmp = 0.f;
|
||||||
|
for (size_t i = 0; i < nreduce; i += 1)
|
||||||
|
{
|
||||||
|
tmp += s_sum[lane_id + i * WARP_SIZE];
|
||||||
|
}
|
||||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue