mtl : add non-broadcast mul kernel
This commit is contained in:
parent
42dca4004c
commit
fbd3f6258d
3 changed files with 27 additions and 4 deletions
|
@ -27,6 +27,10 @@ struct ggml_mtl_context {
|
||||||
id<MTLFunction> function_mul;
|
id<MTLFunction> function_mul;
|
||||||
id<MTLComputePipelineState> pipeline_mul;
|
id<MTLComputePipelineState> pipeline_mul;
|
||||||
|
|
||||||
|
// TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
|
||||||
|
id<MTLFunction> function_mul_row;
|
||||||
|
id<MTLComputePipelineState> pipeline_mul_row;
|
||||||
|
|
||||||
id<MTLFunction> function_scale;
|
id<MTLFunction> function_scale;
|
||||||
id<MTLComputePipelineState> pipeline_scale;
|
id<MTLComputePipelineState> pipeline_scale;
|
||||||
|
|
||||||
|
@ -147,6 +151,10 @@ struct ggml_mtl_context * llama_mtl_init(
|
||||||
ctx->pipeline_mul = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul error:nil];
|
ctx->pipeline_mul = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul error:nil];
|
||||||
fprintf(stderr, "%s: loaded kernel_mul: %p\n", __func__, (void *) ctx->pipeline_mul);
|
fprintf(stderr, "%s: loaded kernel_mul: %p\n", __func__, (void *) ctx->pipeline_mul);
|
||||||
|
|
||||||
|
ctx->function_mul_row = [ctx->library newFunctionWithName:@"kernel_mul_row"];
|
||||||
|
ctx->pipeline_mul_row = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_row error:nil];
|
||||||
|
fprintf(stderr, "%s: loaded kernel_mul_row: %p\n", __func__, (void *) ctx->pipeline_mul_row);
|
||||||
|
|
||||||
ctx->function_scale = [ctx->library newFunctionWithName:@"kernel_scale"];
|
ctx->function_scale = [ctx->library newFunctionWithName:@"kernel_scale"];
|
||||||
ctx->pipeline_scale = [ctx->device newComputePipelineStateWithFunction:ctx->function_scale error:nil];
|
ctx->pipeline_scale = [ctx->device newComputePipelineStateWithFunction:ctx->function_scale error:nil];
|
||||||
fprintf(stderr, "%s: loaded kernel_scale: %p\n", __func__, (void *) ctx->pipeline_scale);
|
fprintf(stderr, "%s: loaded kernel_scale: %p\n", __func__, (void *) ctx->pipeline_scale);
|
||||||
|
@ -336,7 +344,14 @@ int llama_mtl_eval(
|
||||||
|
|
||||||
const int64_t ne00 = gf->nodes[i]->src0->ne[0];
|
const int64_t ne00 = gf->nodes[i]->src0->ne[0];
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul];
|
const int64_t ne10 = gf->nodes[i]->src1->ne[0];
|
||||||
|
|
||||||
|
if (ggml_nelements(gf->nodes[i]->src1) == ne10) {
|
||||||
|
// src1 is a row
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
||||||
|
} else {
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_mul];
|
||||||
|
}
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
|
|
|
@ -39,9 +39,17 @@ kernel void kernel_add(
|
||||||
dst[tpig] = src0[tpig] + src1[tpig];
|
dst[tpig] = src0[tpig] + src1[tpig];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_mul(
|
||||||
|
device const float * src0,
|
||||||
|
device const float * src1,
|
||||||
|
device float * dst,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
dst[tpig] = src0[tpig] * src1[tpig];
|
||||||
|
}
|
||||||
|
|
||||||
// assumption: src1 is a row
|
// assumption: src1 is a row
|
||||||
// broadcast src1 into src0
|
// broadcast src1 into src0
|
||||||
kernel void kernel_mul(
|
kernel void kernel_mul_row(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
|
|
@ -1402,13 +1402,13 @@ static bool llama_eval_internal(
|
||||||
// SILU activation
|
// SILU activation
|
||||||
cur = ggml_silu(ctx0, cur);
|
cur = ggml_silu(ctx0, cur);
|
||||||
|
|
||||||
|
cur = ggml_mul(ctx0, cur, tmp);
|
||||||
|
|
||||||
// TODO: TMP !!!!
|
// TODO: TMP !!!!
|
||||||
if (il == 0) {
|
if (il == 0) {
|
||||||
ggml_set_name(cur, "mtl-check");
|
ggml_set_name(cur, "mtl-check");
|
||||||
}
|
}
|
||||||
|
|
||||||
cur = ggml_mul(ctx0, cur, tmp);
|
|
||||||
|
|
||||||
cur = ggml_mul_mat(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
model.layers[il].w2,
|
model.layers[il].w2,
|
||||||
cur);
|
cur);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue