metal: faster soft_max vial float4
This commit is contained in:
parent
2699cac032
commit
43ca76976d
2 changed files with 49 additions and 1 deletions
|
@ -63,6 +63,7 @@ struct ggml_metal_context {
|
||||||
GGML_METAL_DECL_KERNEL(relu);
|
GGML_METAL_DECL_KERNEL(relu);
|
||||||
GGML_METAL_DECL_KERNEL(gelu);
|
GGML_METAL_DECL_KERNEL(gelu);
|
||||||
GGML_METAL_DECL_KERNEL(soft_max);
|
GGML_METAL_DECL_KERNEL(soft_max);
|
||||||
|
GGML_METAL_DECL_KERNEL(soft_max_4);
|
||||||
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
||||||
|
@ -207,6 +208,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||||
GGML_METAL_ADD_KERNEL(relu);
|
GGML_METAL_ADD_KERNEL(relu);
|
||||||
GGML_METAL_ADD_KERNEL(gelu);
|
GGML_METAL_ADD_KERNEL(gelu);
|
||||||
GGML_METAL_ADD_KERNEL(soft_max);
|
GGML_METAL_ADD_KERNEL(soft_max);
|
||||||
|
GGML_METAL_ADD_KERNEL(soft_max_4);
|
||||||
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
||||||
|
@ -273,6 +275,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||||
GGML_METAL_DEL_KERNEL(relu);
|
GGML_METAL_DEL_KERNEL(relu);
|
||||||
GGML_METAL_DEL_KERNEL(gelu);
|
GGML_METAL_DEL_KERNEL(gelu);
|
||||||
GGML_METAL_DEL_KERNEL(soft_max);
|
GGML_METAL_DEL_KERNEL(soft_max);
|
||||||
|
GGML_METAL_DEL_KERNEL(soft_max_4);
|
||||||
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
||||||
|
@ -796,7 +799,11 @@ void ggml_metal_graph_compute(
|
||||||
{
|
{
|
||||||
const int nth = 32;
|
const int nth = 32;
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
if (ne00%4 == 0) {
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
||||||
|
} else {
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
||||||
|
}
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||||
|
|
|
@ -141,6 +141,47 @@ kernel void kernel_soft_max(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_soft_max_4(
|
||||||
|
device const float * src0,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01,
|
||||||
|
constant int64_t & ne02,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
const int64_t i03 = tgpig[2];
|
||||||
|
const int64_t i02 = tgpig[1];
|
||||||
|
const int64_t i01 = tgpig[0];
|
||||||
|
|
||||||
|
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
|
|
||||||
|
// parallel max
|
||||||
|
float4 lmax4 = psrc4[tpitg[0]];
|
||||||
|
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
||||||
|
lmax4 = fmax(lmax4, psrc4[i00]);
|
||||||
|
}
|
||||||
|
float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||||
|
|
||||||
|
const float max = simd_max(lmax);
|
||||||
|
|
||||||
|
// parallel sum
|
||||||
|
float4 lsum4 = 0.0f;
|
||||||
|
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
||||||
|
const float4 exp_psrc4 = exp(psrc4[i00] - max);
|
||||||
|
lsum4 += exp_psrc4;
|
||||||
|
pdst4[i00] = exp_psrc4;
|
||||||
|
}
|
||||||
|
float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
||||||
|
|
||||||
|
const float sum = simd_sum(lsum);
|
||||||
|
|
||||||
|
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
||||||
|
pdst4[i00] /= sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_diag_mask_inf(
|
kernel void kernel_diag_mask_inf(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue