mtl : add silu kernel
This commit is contained in:
parent
a0cc3de59a
commit
42dca4004c
3 changed files with 37 additions and 5 deletions
|
@ -30,6 +30,9 @@ struct ggml_mtl_context {
|
||||||
id<MTLFunction> function_scale;
|
id<MTLFunction> function_scale;
|
||||||
id<MTLComputePipelineState> pipeline_scale;
|
id<MTLComputePipelineState> pipeline_scale;
|
||||||
|
|
||||||
|
id<MTLFunction> function_silu;
|
||||||
|
id<MTLComputePipelineState> pipeline_silu;
|
||||||
|
|
||||||
id<MTLFunction> function_relu;
|
id<MTLFunction> function_relu;
|
||||||
id<MTLComputePipelineState> pipeline_relu;
|
id<MTLComputePipelineState> pipeline_relu;
|
||||||
|
|
||||||
|
@ -148,6 +151,10 @@ struct ggml_mtl_context * llama_mtl_init(
|
||||||
ctx->pipeline_scale = [ctx->device newComputePipelineStateWithFunction:ctx->function_scale error:nil];
|
ctx->pipeline_scale = [ctx->device newComputePipelineStateWithFunction:ctx->function_scale error:nil];
|
||||||
fprintf(stderr, "%s: loaded kernel_scale: %p\n", __func__, (void *) ctx->pipeline_scale);
|
fprintf(stderr, "%s: loaded kernel_scale: %p\n", __func__, (void *) ctx->pipeline_scale);
|
||||||
|
|
||||||
|
ctx->function_silu = [ctx->library newFunctionWithName:@"kernel_silu"];
|
||||||
|
ctx->pipeline_silu = [ctx->device newComputePipelineStateWithFunction:ctx->function_silu error:nil];
|
||||||
|
fprintf(stderr, "%s: loaded kernel_silu: %p\n", __func__, (void *) ctx->pipeline_silu);
|
||||||
|
|
||||||
ctx->function_relu = [ctx->library newFunctionWithName:@"kernel_relu"];
|
ctx->function_relu = [ctx->library newFunctionWithName:@"kernel_relu"];
|
||||||
ctx->pipeline_relu = [ctx->device newComputePipelineStateWithFunction:ctx->function_relu error:nil];
|
ctx->pipeline_relu = [ctx->device newComputePipelineStateWithFunction:ctx->function_relu error:nil];
|
||||||
fprintf(stderr, "%s: loaded kernel_relu: %p\n", __func__, (void *) ctx->pipeline_relu);
|
fprintf(stderr, "%s: loaded kernel_relu: %p\n", __func__, (void *) ctx->pipeline_relu);
|
||||||
|
@ -357,6 +364,23 @@ int llama_mtl_eval(
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(gf->nodes[i]);
|
const int64_t n = ggml_nelements(gf->nodes[i]);
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
} break;
|
||||||
|
case GGML_OP_SILU:
|
||||||
|
{
|
||||||
|
if (encoder == nil) {
|
||||||
|
encoder = [command_buffer computeCommandEncoder];
|
||||||
|
}
|
||||||
|
|
||||||
|
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
|
||||||
|
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_silu];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
|
const int64_t n = ggml_nelements(gf->nodes[i]);
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_RELU:
|
case GGML_OP_RELU:
|
||||||
|
|
|
@ -58,6 +58,14 @@ kernel void kernel_scale(
|
||||||
dst[tpig] = src0[tpig] * scale;
|
dst[tpig] = src0[tpig] * scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SiLU activation kernel: for the element at this thread's grid position,
// reads x from src0 and writes x / (1 + exp(-x)) (i.e. x * sigmoid(x)) to dst.
//
//   src0 - input float buffer (read-only)
//   dst  - output float buffer, indexed identically to src0
//   tpig - this thread's position in the dispatch grid; used directly as the
//          element index for both buffers
//
// NOTE(review): there is no bounds check on tpig, so the dispatch grid is
// presumably sized to exactly the element count by the host - confirm against
// the caller that encodes this kernel.
kernel void kernel_silu(
|
||||||
|
device const float * src0,
|
||||||
|
device float * dst,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
float x = src0[tpig];
|
||||||
|
||||||
|
dst[tpig] = x / (1.0f + exp(-x));
|
||||||
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_relu(
|
kernel void kernel_relu(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
|
10
llama.cpp
10
llama.cpp
|
@ -1371,11 +1371,6 @@ static bool llama_eval_internal(
|
||||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
|
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
|
||||||
ggml_set_name(cur, "KQV_merged_contiguous");
|
ggml_set_name(cur, "KQV_merged_contiguous");
|
||||||
|
|
||||||
// TODO: TMP !!!!
|
|
||||||
if (il == 0) {
|
|
||||||
ggml_set_name(cur, "mtl-check");
|
|
||||||
}
|
|
||||||
|
|
||||||
// projection (no bias)
|
// projection (no bias)
|
||||||
cur = ggml_mul_mat(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
model.layers[il].wo,
|
model.layers[il].wo,
|
||||||
|
@ -1407,6 +1402,11 @@ static bool llama_eval_internal(
|
||||||
// SILU activation
|
// SILU activation
|
||||||
cur = ggml_silu(ctx0, cur);
|
cur = ggml_silu(ctx0, cur);
|
||||||
|
|
||||||
|
// TODO: TMP !!!!
|
||||||
|
if (il == 0) {
|
||||||
|
ggml_set_name(cur, "mtl-check");
|
||||||
|
}
|
||||||
|
|
||||||
cur = ggml_mul(ctx0, cur, tmp);
|
cur = ggml_mul(ctx0, cur, tmp);
|
||||||
|
|
||||||
cur = ggml_mul_mat(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue