Added kernel_mul_mv_f16_f32_l4_large which performs 32x more ops
Most of the time, kernel_mul_mv_f16_f32_l4 is called to perform 4 FP ops per thread. Added kernel_mul_mv_f16_f32_l4_large which performs 128 FP ops per thread, when there are 32x less threads.
This commit is contained in:
parent
d041d2ceaa
commit
cd2322c996
1 changed files with 12 additions and 2 deletions
14
ggml-metal.m
14
ggml-metal.m
|
@ -83,6 +83,7 @@ enum ggml_metal_kernel_type {
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4_LARGE,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
||||||
|
@ -533,6 +534,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4_LARGE, mul_mv_f16_f32_l4_large, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
|
||||||
|
@ -1586,7 +1588,11 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
nth1 = 1;
|
nth1 = 1;
|
||||||
if (src1t == GGML_TYPE_F32) {
|
if (src1t == GGML_TYPE_F32) {
|
||||||
if (ne11 * ne12 < 4) {
|
if (ne11 * ne12 < 4) {
|
||||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
if (ne01 > 128) {
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4_LARGE].pipeline;
|
||||||
|
} else {
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
||||||
|
}
|
||||||
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
||||||
nrows = ne11;
|
nrows = ne11;
|
||||||
|
@ -1778,7 +1784,11 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
} else {
|
} else {
|
||||||
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
if (ne01 > 128) {
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01/32, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
|
} else {
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue