mtl : add mul kernel + confirm working

Georgi Gerganov 2023-05-30 19:15:38 +03:00
parent 72256ebd2b
commit 64afc0b53a
3 changed files with 44 additions and 4 deletions


@@ -24,6 +24,9 @@ struct ggml_mtl_context {
     id<MTLFunction>             function_add;
     id<MTLComputePipelineState> pipeline_add;
 
+    id<MTLFunction>             function_mul;
+    id<MTLComputePipelineState> pipeline_mul;
+
     id<MTLFunction>             function_relu;
     id<MTLComputePipelineState> pipeline_relu;
@@ -119,6 +122,10 @@ struct ggml_mtl_context * llama_mtl_init(
     ctx->pipeline_add = [ctx->device newComputePipelineStateWithFunction:ctx->function_add error:nil];
     fprintf(stderr, "%s: loaded kernel_add: %p\n", __func__, (void *) ctx->pipeline_add);
 
+    ctx->function_mul = [ctx->library newFunctionWithName:@"kernel_mul"];
+    ctx->pipeline_mul = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul error:nil];
+    fprintf(stderr, "%s: loaded kernel_mul: %p\n", __func__, (void *) ctx->pipeline_mul);
+
     ctx->function_relu = [ctx->library newFunctionWithName:@"kernel_relu"];
     ctx->pipeline_relu = [ctx->device newComputePipelineStateWithFunction:ctx->function_relu error:nil];
     fprintf(stderr, "%s: loaded kernel_relu: %p\n", __func__, (void *) ctx->pipeline_relu);
@@ -253,6 +260,28 @@ int llama_mtl_eval(
                     const int64_t n = ggml_nelements(gf->nodes[i]);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case GGML_OP_MUL:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
+                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    const int64_t ne00 = gf->nodes[i]->src0->ne[0];
+
+                    [encoder setComputePipelineState:ctx->pipeline_mul];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+
+                    const int64_t n = ggml_nelements(gf->nodes[i]);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
             case GGML_OP_RELU:

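Note: the GGML_OP_MUL case above launches ggml_nelements(dst) threadgroups of a single thread each, so thread_position_in_grid enumerates the destination elements in flat order, and ne00 (the row length of src0, which is also the length of src1) is passed to the kernel as a small constant via setBytes. A minimal CPU-side sketch of how that flat index maps onto a row/column pair (illustrative only; the names are not from this commit):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t ne00 = 4;             // row length of src0, also the length of src1
        const int64_t n    = 2 * ne00;      // total number of elements (two rows here)

        for (int64_t gid = 0; gid < n; ++gid) {
            const int64_t row = gid / ne00; // which row of src0 this thread handles
            const int64_t col = gid % ne00; // which element of src1 gets broadcast
            printf("gid %2lld -> src0 row %lld, src1[%lld]\n",
                   (long long) gid, (long long) row, (long long) col);
        }

        return 0;
    }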

@@ -42,6 +42,17 @@ kernel void kernel_add(
     dst[gid] = src0[gid] + src1[gid];
 }
 
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_mul(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        uint gid[[thread_position_in_grid]]) {
+    dst[gid] = src0[gid] * src1[gid % ne00];
+}
+
 kernel void kernel_relu(
         device const float * src0,
         device       float * dst,

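Note: kernel_mul assumes src1 is a single row of ne00 values and reuses it for every row of src0, which is what the gid % ne00 indexing implements. A CPU reference of the same computation, handy as a sanity check (a sketch; mul_row_ref is a hypothetical helper, not part of this commit):

    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    // Reference version of kernel_mul: dst[i] = src0[i] * src1[i % ne00].
    // src1 holds a single row of ne00 values broadcast over every row of src0.
    static void mul_row_ref(const float * src0, const float * src1, float * dst,
                            int64_t n, int64_t ne00) {
        for (int64_t i = 0; i < n; ++i) {
            dst[i] = src0[i] * src1[i % ne00];
        }
    }

    int main(void) {
        // two rows of three elements, scaled by one three-element row
        const float src0[6] = {1, 2, 3, 4, 5, 6};
        const float src1[3] = {10, 100, 1000};
        float dst[6];

        mul_row_ref(src0, src1, dst, 6, 3);

        // the second row is scaled by the same src1 values as the first
        assert(dst[0] == 10 && dst[4] == 500 && dst[5] == 6000);
        printf("ok\n");
        return 0;
    }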

@@ -1263,13 +1263,13 @@ static bool llama_eval_internal(
         // norm
         {
             cur = ggml_rms_norm(ctx0, inpL);
 
+            // cur = cur*attention_norm(broadcasted)
+            cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
+
             // TODO: TMP !!!!
             if (il == 0) {
                 ggml_set_name(cur, "mtl-check");
             }
-
-            // cur = cur*attention_norm(broadcasted)
-            cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
         }
 
         // self-attention
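
Note: moving the broadcasted multiply ahead of the ggml_set_name call makes the "mtl-check" tensor point at the output of ggml_mul for layer 0, so the value produced by the new Metal kernel can be checked against the CPU graph (the "confirm working" part of this commit). The comparison itself is not shown in this diff; a sketch of the kind of element-wise check one might run, assuming both results are available as plain float arrays (hypothetical helper, not part of this commit):

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    // Report the largest absolute difference between a CPU result and a GPU result.
    static float max_abs_diff(const float * cpu, const float * gpu, int64_t n) {
        float max_err = 0.0f;
        for (int64_t i = 0; i < n; ++i) {
            const float err = fabsf(cpu[i] - gpu[i]);
            if (err > max_err) {
                max_err = err;
            }
        }
        return max_err;
    }

    int main(void) {
        const float cpu[4] = {0.25f, -1.5f, 3.0f, 0.0f};
        const float gpu[4] = {0.25f, -1.5f, 3.0f, 1e-6f};
        printf("max abs diff: %g\n", (double) max_abs_diff(cpu, gpu, 4));
        return 0;
    }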