fix softmax
This commit is contained in:
parent
e50517b64f
commit
d70305b343
1 changed file with 16 additions and 4 deletions
|
@@ -14,7 +14,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
||||||
|
|
||||||
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
||||||
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
||||||
|
const int nthreads = block_size;
|
||||||
|
const int nwarps = nthreads / WARP_SIZE;
|
||||||
|
int nreduce = nwarps / WARP_SIZE;
|
||||||
float slope = 1.0f;
|
float slope = 1.0f;
|
||||||
|
|
||||||
// ALiBi
|
// ALiBi
|
||||||
|
@@ -27,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
||||||
slope = sycl::pow(base, float(exp));
|
slope = sycl::pow(base, float(exp));
|
||||||
}
|
}
|
||||||
|
|
||||||
float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols;
|
float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
|
||||||
float max_val = -INFINITY;
|
float max_val = -INFINITY;
|
||||||
|
|
||||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||||
|
@@ -51,6 +53,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
||||||
if (block_size > WARP_SIZE) {
|
if (block_size > WARP_SIZE) {
|
||||||
if (warp_id == 0) {
|
if (warp_id == 0) {
|
||||||
buf[lane_id] = -INFINITY;
|
buf[lane_id] = -INFINITY;
|
||||||
|
for (size_t i = 1; i < nreduce; i += 1)
|
||||||
|
buf[lane_id + i * WARP_SIZE] = -INFINITY;
|
||||||
}
|
}
|
||||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||||
|
|
||||||
|
@@ -58,13 +62,15 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
||||||
buf[warp_id] = max_val;
|
buf[warp_id] = max_val;
|
||||||
}
|
}
|
||||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||||
|
|
||||||
max_val = buf[lane_id];
|
max_val = buf[lane_id];
|
||||||
|
for (size_t i = 1; i < nreduce; i += 1)
|
||||||
|
{
|
||||||
|
max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
|
||||||
|
}
|
||||||
max_val = warp_reduce_max(max_val, item_ct1);
|
max_val = warp_reduce_max(max_val, item_ct1);
|
||||||
}
|
}
|
||||||
|
|
||||||
float tmp = 0.f;
|
float tmp = 0.f;
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||||
const int col = col0 + tid;
|
const int col = col0 + tid;
|
||||||
|
@@ -83,6 +89,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
||||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||||
if (warp_id == 0) {
|
if (warp_id == 0) {
|
||||||
buf[lane_id] = 0.f;
|
buf[lane_id] = 0.f;
|
||||||
|
for (size_t i = 1; i < nreduce; i += 1)
|
||||||
|
buf[lane_id + i * WARP_SIZE] = 0.f;
|
||||||
}
|
}
|
||||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||||
|
|
||||||
|
@@ -92,6 +100,10 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
||||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||||
|
|
||||||
tmp = buf[lane_id];
|
tmp = buf[lane_id];
|
||||||
|
for (size_t i = 1; i < nreduce; i += 1)
|
||||||
|
{
|
||||||
|
tmp += buf[lane_id + i * WARP_SIZE];
|
||||||
|
}
|
||||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue