wip
This commit is contained in:
parent
26cd4237bc
commit
a041ced0fd
4 changed files with 176 additions and 96 deletions
|
@ -401,8 +401,10 @@ kernel void kernel_soft_max(
|
|||
// parallel max
|
||||
float lmax = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
||||
for (int i00 = 32*tpitg; i00 < ne00; i00 += 32*ntg) {
|
||||
for (int t = 0; t < 32 && i00 < ne00; ++t, ++i00) {
|
||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
||||
}
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
|
@ -426,10 +428,12 @@ kernel void kernel_soft_max(
|
|||
|
||||
// parallel sum
|
||||
float lsum = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
||||
lsum += exp_psrc0;
|
||||
pdst[i00] = exp_psrc0;
|
||||
for (int i00 = 32*tpitg; i00 < ne00; i00 += 32*ntg) {
|
||||
for (int t = 0; t < 32 && i00 < ne00; ++t, ++i00) {
|
||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
||||
lsum += exp_psrc0;
|
||||
pdst[i00] = exp_psrc0;
|
||||
}
|
||||
}
|
||||
|
||||
// This barrier fixes a failing test
|
||||
|
@ -457,8 +461,10 @@ kernel void kernel_soft_max(
|
|||
|
||||
const float inv_sum = 1.0f/sum;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
pdst[i00] *= inv_sum;
|
||||
for (int i00 = 32*tpitg; i00 < ne00; i00 += 32*ntg) {
|
||||
for (int t = 0; t < 32 && i00 < ne00; ++t, ++i00) {
|
||||
pdst[i00] *= inv_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -503,8 +509,10 @@ kernel void kernel_soft_max_4(
|
|||
// parallel max
|
||||
float4 lmax4 = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
||||
for (int i00 = 8*tpitg; i00 < ne00/4; i00 += 8*ntg) {
|
||||
for (int t = 0; t < 8 && i00 < ne00/4; ++t, ++i00) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
||||
}
|
||||
}
|
||||
|
||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||
|
@ -529,10 +537,12 @@ kernel void kernel_soft_max_4(
|
|||
|
||||
// parallel sum
|
||||
float4 lsum4 = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
for (int i00 = 8*tpitg; i00 < ne00/4; i00 += 8*ntg) {
|
||||
for (int t = 0; t < 8 && i00 < ne00/4; ++t, ++i00) {
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
}
|
||||
}
|
||||
|
||||
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
||||
|
@ -562,8 +572,10 @@ kernel void kernel_soft_max_4(
|
|||
|
||||
const float inv_sum = 1.0f/sum;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
pdst4[i00] *= inv_sum;
|
||||
for (int i00 = 8*tpitg; i00 < ne00/4; i00 += 8*ntg) {
|
||||
for (int t = 0; t < 8 && i00 < ne00/4; ++t, ++i00) {
|
||||
pdst4[i00] *= inv_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue