Various other speedups for "small" kernels

Iwan Kawrakow 2023-09-07 18:12:13 +02:00
parent 7c8c6ce085
commit 2699cac032
2 changed files with 15 additions and 52 deletions

ggml-metal.m

@@ -762,7 +762,7 @@ void ggml_metal_graph_compute(
     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
     [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

-    const int64_t n = ggml_nelements(dst);
+    const int64_t n = ggml_nelements(dst)/4;

     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 } break;
@@ -782,7 +782,7 @@ void ggml_metal_graph_compute(
     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
     [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

-    const int64_t n = ggml_nelements(dst);
+    const int64_t n = ggml_nelements(dst)/4;

     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 } break;
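In both hunks above the grid shrinks by a factor of 4 because the SILU and GELU kernels in the Metal source (below) are rewritten to operate on float4: each thread now loads and stores one float4, i.e. four scalar activations. A minimal sketch of the accounting, not part of the patch, using the names already in the diff and assuming the tensor's element count is a multiple of 4:

    // thread tpig handles scalar elements 4*tpig .. 4*tpig + 3 of dst
    const int64_t n = ggml_nelements(dst)/4;   // assumes ggml_nelements(dst) % 4 == 0
    // n threads * 4 floats per thread covers all ggml_nelements(dst) elements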
@@ -802,7 +802,6 @@ void ggml_metal_graph_compute(
     [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
     [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
     [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-    [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];

     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;

ggml-metal.metal

@@ -71,10 +71,10 @@ kernel void kernel_scale(
 }

 kernel void kernel_silu(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
     dst[tpig] = x / (1.0f + exp(-x));
 }
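Read together, the new lines give the vectorized kernel below (assembled here from this hunk for reference). Metal's exp and the division are applied component-wise to float4, so each thread computes SiLU for four values at once:

    kernel void kernel_silu(
            device const float4 * src0,
            device       float4 * dst,
            uint tpig[[thread_position_in_grid]]) {
        // one float4 (four scalar activations) per thread
        device const float4 & x = src0[tpig];
        dst[tpig] = x / (1.0f + exp(-x));   // x * sigmoid(x), component-wise
    }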
@@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

 kernel void kernel_gelu(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];

     // BEWARE !!!
     // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
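This hunk only touches the signature and the load of x; the rest of the body lies outside it. For orientation, the two constants above are the ones used in the standard tanh approximation of GELU, so the vectorized body presumably evaluates something along these lines, component-wise per float4 lane (a sketch and an assumption, not taken from this diff):

    // GELU, tanh approximation (sketch - the actual body is not shown in this hunk):
    //   y = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)))
    // per the BEWARE comment above, precise::tanh is needed here to avoid NaNs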
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
-        threadgroup float * buf [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {
@@ -119,58 +118,23 @@ kernel void kernel_soft_max(
     device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

     // parallel max
-    buf[tpitg[0]] = -INFINITY;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
+    float lmax = psrc0[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+        lmax = MAX(lmax, psrc0[i00]);
     }
-
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-
-    //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
-    // the loop, and when that is done, buf[0] has the correct (synchronized) value
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
-
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float max = buf[0];
+    const float max = simd_max(lmax);

     // parallel sum
-    buf[tpitg[0]] = 0.0f;
+    float lsum = 0.0f;
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
-        buf[tpitg[0]] += exp_psrc0;
+        lsum += exp_psrc0;
         // Remember the result of exp here. exp is expensive, so we really do not
         // whish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
-
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] += buf[tpitg[0] + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-
-    // broadcast - not needed, see above
-    //// broadcast
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float sum = buf[0];
+    const float sum = simd_sum(lsum);

     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         pdst[i00] /= sum;
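Net effect of this last hunk: the two tree reductions through threadgroup memory, with their barriers, are replaced by simd_max and simd_sum, which combine the per-thread registers of a SIMD group and return the result to every lane with no shared memory and no barrier. That is also why the threadgroup float * buf parameter and the host-side setThreadgroupMemoryLength call are dropped earlier in the diff. Two implicit assumptions worth noting (my reading, not stated in the patch): the reduction only spans one SIMD group, so the threadgroup must be no wider than the SIMD-group width (32 threads on Apple GPUs), and seeding lmax from psrc0[tpitg[0]] instead of -INFINITY requires ne00 >= ntg[0]. A minimal sketch of the pattern, reusing the kernel's own names:

    // each thread reduces a strided slice of the row into a register ...
    float lmax = psrc0[tpitg[0]];                    // in range only if ne00 >= ntg[0]
    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
        lmax = MAX(lmax, psrc0[i00]);
    }
    // ... then one SIMD-group intrinsic finishes the reduction; every lane gets
    // the same value, valid while ntg[0] does not exceed the SIMD-group width
    const float max = simd_max(lmax);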