SYCL softmax: Initialize nreduce as size_t
This commit is contained in:
parent
fe5afd4a2d
commit
9129362c6f
1 changed files with 5 additions and 5 deletions
|
@ -16,7 +16,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
|||
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
||||
const int nthreads = block_size;
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
int nreduce = nwarps / WARP_SIZE;
|
||||
size_t nreduce = nwarps / WARP_SIZE;
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
|
@ -53,7 +53,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
|||
if (block_size > WARP_SIZE) {
|
||||
if (warp_id == 0) {
|
||||
buf[lane_id] = -INFINITY;
|
||||
for (size_t i = 1; i < (size_t) nreduce; i += 1) {
|
||||
for (size_t i = 1; i < nreduce; i += 1) {
|
||||
buf[lane_id + i * WARP_SIZE] = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
@ -64,7 +64,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
|||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
max_val = buf[lane_id];
|
||||
for (size_t i = 1; i < (size_t) nreduce; i += 1) {
|
||||
for (size_t i = 1; i < nreduce; i += 1) {
|
||||
max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
|
||||
}
|
||||
max_val = warp_reduce_max(max_val, item_ct1);
|
||||
|
@ -89,7 +89,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
|||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
if (warp_id == 0) {
|
||||
buf[lane_id] = 0.f;
|
||||
for (size_t i = 1; i < (size_t) nreduce; i += 1) {
|
||||
for (size_t i = 1; i < nreduce; i += 1) {
|
||||
buf[lane_id + i * WARP_SIZE] = 0.f;
|
||||
}
|
||||
}
|
||||
|
@ -101,7 +101,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
|||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
tmp = buf[lane_id];
|
||||
for (size_t i = 1; i < (size_t) nreduce; i += 1) {
|
||||
for (size_t i = 1; i < nreduce; i += 1) {
|
||||
tmp += buf[lane_id + i * WARP_SIZE];
|
||||
}
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue