ggml : ggml_soft_max support F16/F32 mask/pos
ggml-ci
This commit is contained in:
parent
c11d05fec0
commit
f725ca90fb
6 changed files with 105 additions and 34 deletions
29
ggml-metal.m
29
ggml-metal.m
|
@ -46,8 +46,10 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
||||
GGML_METAL_KERNEL_TYPE_SILU,
|
||||
GGML_METAL_KERNEL_TYPE_SILU_4,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
|
||||
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
|
||||
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
|
||||
|
@ -492,8 +494,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
||||
|
@ -1346,22 +1350,33 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
|
||||
|
||||
if (ne00%4 == 0) {
|
||||
while (nth < ne00/4 && nth < 256) {
|
||||
nth *= 2;
|
||||
}
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
|
||||
if (use_f16) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
|
||||
}
|
||||
} else {
|
||||
while (nth < ne00 && nth < 1024) {
|
||||
nth *= 2;
|
||||
}
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
||||
if (use_f16) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
|
||||
}
|
||||
}
|
||||
|
||||
float scale;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue