metal : support unary ops for nelements % 4 != 0

This commit is contained in:
Georgi Gerganov 2024-04-15 22:37:25 +03:00
parent 4256fe6a1a
commit 70482f587f
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 66 additions and 15 deletions

View file

@ -41,8 +41,11 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_TANH,
GGML_METAL_KERNEL_TYPE_RELU,
GGML_METAL_KERNEL_TYPE_GELU,
GGML_METAL_KERNEL_TYPE_GELU_4,
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
GGML_METAL_KERNEL_TYPE_SILU,
GGML_METAL_KERNEL_TYPE_SILU_4,
GGML_METAL_KERNEL_TYPE_SOFT_MAX,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
@ -473,8 +476,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
@ -1204,42 +1210,60 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
case GGML_UNARY_OP_GELU:
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
int64_t n = ggml_nelements(dst);
id<MTLComputePipelineState> pipeline = nil;
if (n%4 == 0) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
n /= 4;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 4 == 0);
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_GELU_QUICK:
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
int64_t n = ggml_nelements(dst);
id<MTLComputePipelineState> pipeline = nil;
if (n%4 == 0) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
n /= 4;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 4 == 0);
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_SILU:
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
int64_t n = ggml_nelements(dst);
id<MTLComputePipelineState> pipeline = nil;
if (n%4 == 0) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
n /= 4;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 4 == 0);
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:
{

View file

@ -242,6 +242,15 @@ constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
kernel void kernel_gelu(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
kernel void kernel_gelu_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
@ -255,6 +264,15 @@ kernel void kernel_gelu(
}
kernel void kernel_gelu_quick(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
kernel void kernel_gelu_quick_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
@ -264,6 +282,14 @@ kernel void kernel_gelu_quick(
}
kernel void kernel_silu(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = x / (1.0f + exp(-x));
}
kernel void kernel_silu_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {

View file

@ -1878,6 +1878,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
// unary ops
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
test_cases.emplace_back(new test_unary((ggml_unary_op) op));
test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 7, 13, 19, 23 }));
}
test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));