metal : fix soft_max kernels
ref: https://github.com/ggerganov/ggml/pull/621/commits/1914017863d2f9ab8ecc0281cc2a56d683668b92
This commit is contained in:
parent
82e4f64578
commit
ab558ac2b3
2 changed files with 13 additions and 7 deletions
|
@ -347,9 +347,9 @@ kernel void kernel_soft_max(
|
|||
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
||||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
|
||||
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
||||
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
// parallel max
|
||||
float lmax = -INFINITY;
|
||||
|
@ -386,6 +386,8 @@ kernel void kernel_soft_max(
|
|||
}
|
||||
|
||||
float sum = simd_sum(lsum);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
|
@ -428,9 +430,9 @@ kernel void kernel_soft_max_4(
|
|||
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
||||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
||||
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
||||
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
|
||||
// parallel max
|
||||
float4 lmax4 = -INFINITY;
|
||||
|
@ -468,6 +470,8 @@ kernel void kernel_soft_max_4(
|
|||
}
|
||||
|
||||
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float sum = simd_sum(lsum);
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue