mtl : add scale kernel

This commit is contained in:
Georgi Gerganov 2023-06-01 19:52:32 +03:00
parent 51efb59437
commit 0f1c580860
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 39 additions and 4 deletions

View file

@@ -27,6 +27,9 @@ struct ggml_mtl_context {
id<MTLFunction> function_mul; id<MTLFunction> function_mul;
id<MTLComputePipelineState> pipeline_mul; id<MTLComputePipelineState> pipeline_mul;
id<MTLFunction> function_scale;
id<MTLComputePipelineState> pipeline_scale;
id<MTLFunction> function_relu; id<MTLFunction> function_relu;
id<MTLComputePipelineState> pipeline_relu; id<MTLComputePipelineState> pipeline_relu;
@@ -135,6 +138,10 @@ struct ggml_mtl_context * llama_mtl_init(
ctx->pipeline_mul = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul error:nil]; ctx->pipeline_mul = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul error:nil];
fprintf(stderr, "%s: loaded kernel_mul: %p\n", __func__, (void *) ctx->pipeline_mul); fprintf(stderr, "%s: loaded kernel_mul: %p\n", __func__, (void *) ctx->pipeline_mul);
ctx->function_scale = [ctx->library newFunctionWithName:@"kernel_scale"];
ctx->pipeline_scale = [ctx->device newComputePipelineStateWithFunction:ctx->function_scale error:nil];
fprintf(stderr, "%s: loaded kernel_scale: %p\n", __func__, (void *) ctx->pipeline_scale);
ctx->function_relu = [ctx->library newFunctionWithName:@"kernel_relu"]; ctx->function_relu = [ctx->library newFunctionWithName:@"kernel_relu"];
ctx->pipeline_relu = [ctx->device newComputePipelineStateWithFunction:ctx->function_relu error:nil]; ctx->pipeline_relu = [ctx->device newComputePipelineStateWithFunction:ctx->function_relu error:nil];
fprintf(stderr, "%s: loaded kernel_relu: %p\n", __func__, (void *) ctx->pipeline_relu); fprintf(stderr, "%s: loaded kernel_relu: %p\n", __func__, (void *) ctx->pipeline_relu);
@@ -310,6 +317,26 @@ int llama_mtl_eval(
const int64_t n = ggml_nelements(gf->nodes[i]); const int64_t n = ggml_nelements(gf->nodes[i]);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SCALE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
const float scale = *(const float *) gf->nodes[i]->src1->data;
[encoder setComputePipelineState:ctx->pipeline_scale];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
const int64_t n = ggml_nelements(gf->nodes[i]);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
case GGML_OP_RELU: case GGML_OP_RELU:

View file

@@ -53,6 +53,14 @@ kernel void kernel_mul(
dst[tpig] = src0[tpig] * src1[tpig % ne00]; dst[tpig] = src0[tpig] * src1[tpig % ne00];
} }
// Element-wise scale: dst[i] = src0[i] * scale.
// Dispatched with one thread per element (grid size = ggml_nelements, threadgroup 1x1x1
// per the encoder call in llama_mtl_eval); tpig is the flat element index.
// `scale` is passed via setBytes as a single float read from src1->data.
kernel void kernel_scale(
device const float * src0,
device float * dst,
constant float & scale,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * scale;
}
kernel void kernel_relu( kernel void kernel_relu(
device const float * src0, device const float * src0,
device float * dst, device float * dst,

View file

@@ -1324,10 +1324,6 @@ static bool llama_eval_internal(
// K * Q // K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
ggml_set_name(KQ, "KQ"); ggml_set_name(KQ, "KQ");
// TODO: TMP !!!!
if (il == 0) {
ggml_set_name(KQ, "mtl-check");
}
// KQ_scaled = KQ / sqrt(n_embd/n_head) // KQ_scaled = KQ / sqrt(n_embd/n_head)
struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)); struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
@@ -1336,6 +1332,10 @@ static bool llama_eval_internal(
// KQ_scaled shape [n_past + N, N, n_head, 1] // KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
ggml_set_name(KQ_scaled, "KQ_scaled"); ggml_set_name(KQ_scaled, "KQ_scaled");
// TODO: TMP !!!!
if (il == 0) {
ggml_set_name(KQ_scaled, "mtl-check");
}
// KQ_masked = mask_past(KQ_scaled) // KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);