This commit is contained in:
Georgi Gerganov 2024-05-20 17:00:55 +03:00
parent 26cd4237bc
commit a041ced0fd
No known key found for this signature in database
GPG key ID: BF970631944C16B7
4 changed files with 176 additions and 96 deletions

View file

@ -401,8 +401,10 @@ kernel void kernel_soft_max(
// parallel max
float lmax = -INFINITY;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
for (int i00 = 32*tpitg; i00 < ne00; i00 += 32*ntg) {
for (int t = 0; t < 32 && i00 < ne00; ++t, ++i00) {
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
}
}
// find the max value in the block
@ -426,10 +428,12 @@ kernel void kernel_soft_max(
// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
lsum += exp_psrc0;
pdst[i00] = exp_psrc0;
for (int i00 = 32*tpitg; i00 < ne00; i00 += 32*ntg) {
for (int t = 0; t < 32 && i00 < ne00; ++t, ++i00) {
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
lsum += exp_psrc0;
pdst[i00] = exp_psrc0;
}
}
// This barrier fixes a failing test
@ -457,8 +461,10 @@ kernel void kernel_soft_max(
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
pdst[i00] *= inv_sum;
for (int i00 = 32*tpitg; i00 < ne00; i00 += 32*ntg) {
for (int t = 0; t < 32 && i00 < ne00; ++t, ++i00) {
pdst[i00] *= inv_sum;
}
}
}
@ -503,8 +509,10 @@ kernel void kernel_soft_max_4(
// parallel max
float4 lmax4 = -INFINITY;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
for (int i00 = 8*tpitg; i00 < ne00/4; i00 += 8*ntg) {
for (int t = 0; t < 8 && i00 < ne00/4; ++t, ++i00) {
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
}
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@ -529,10 +537,12 @@ kernel void kernel_soft_max_4(
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
for (int i00 = 8*tpitg; i00 < ne00/4; i00 += 8*ntg) {
for (int t = 0; t < 8 && i00 < ne00/4; ++t, ++i00) {
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
}
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
@ -562,8 +572,10 @@ kernel void kernel_soft_max_4(
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
pdst4[i00] *= inv_sum;
for (int i00 = 8*tpitg; i00 < ne00/4; i00 += 8*ntg) {
for (int t = 0; t < 8 && i00 < ne00/4; ++t, ++i00) {
pdst4[i00] *= inv_sum;
}
}
}