metal : implement soft_max_ext
This commit is contained in:
parent
1f5cd83275
commit
e89597c062
5 changed files with 75 additions and 31 deletions
|
@ -180,10 +180,12 @@ kernel void kernel_gelu(
|
|||
|
||||
kernel void kernel_soft_max(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant float & scale,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
|
@ -194,14 +196,15 @@ kernel void kernel_soft_max(
|
|||
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
||||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
|
||||
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
// parallel max
|
||||
float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
|
||||
float lmax = (tpitg < ne00) ? (psrc0[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY;
|
||||
|
||||
for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
|
||||
lmax = MAX(lmax, psrc0[i00]);
|
||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
||||
}
|
||||
|
||||
float max = simd_max(lmax);
|
||||
|
@ -225,7 +228,7 @@ kernel void kernel_soft_max(
|
|||
// parallel sum
|
||||
float lsum = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
const float exp_psrc0 = exp(psrc0[i00] - max);
|
||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max);
|
||||
lsum += exp_psrc0;
|
||||
// Remember the result of exp here. exp is expensive, so we really do not
|
||||
// wish to compute it twice.
|
||||
|
@ -257,10 +260,12 @@ kernel void kernel_soft_max(
|
|||
|
||||
kernel void kernel_soft_max_4(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant float & scale,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
|
@ -271,14 +276,15 @@ kernel void kernel_soft_max_4(
|
|||
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
||||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
||||
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
|
||||
// parallel max
|
||||
float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
|
||||
float4 lmax4 = tpitg < ne00/4 ? (psrc4[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY;
|
||||
|
||||
for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]);
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
||||
}
|
||||
|
||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||
|
@ -303,7 +309,7 @@ kernel void kernel_soft_max_4(
|
|||
// parallel sum
|
||||
float4 lsum4 = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
const float4 exp_psrc4 = exp(psrc4[i00] - max);
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue