mtl : add rms_norm kernel + confirm working

This commit is contained in:
Georgi Gerganov 2023-05-30 19:03:04 +03:00
parent 794704e409
commit 72256ebd2b
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
5 changed files with 66 additions and 15 deletions

View file

@ -32,6 +32,9 @@ struct ggml_mtl_context {
id<MTLFunction> function_get_rows_q4_0;
id<MTLComputePipelineState> pipeline_get_rows_q4_0;
id<MTLFunction> function_rms_norm;
id<MTLComputePipelineState> pipeline_rms_norm;
};
// MSL code
@ -127,6 +130,10 @@ struct ggml_mtl_context * llama_mtl_init(
ctx->function_get_rows_q4_0 = [ctx->library newFunctionWithName:@"kernel_get_rows_q4_0"];
ctx->pipeline_get_rows_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_get_rows_q4_0 error:nil];
fprintf(stderr, "%s: loaded kernel_get_rows_q4_0: %p\n", __func__, (void *) ctx->pipeline_get_rows_q4_0);
ctx->function_rms_norm = [ctx->library newFunctionWithName:@"kernel_rms_norm"];
ctx->pipeline_rms_norm = [ctx->device newComputePipelineStateWithFunction:ctx->function_rms_norm error:nil];
fprintf(stderr, "%s: loaded kernel_rms_norm: %p\n", __func__, (void *) ctx->pipeline_rms_norm);
}
// MTLBuffer approach
@ -348,6 +355,30 @@ int llama_mtl_eval(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_RMS_NORM:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
const int64_t ne00 = gf->nodes[i]->src0->ne[0];
const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
const float eps = 1e-6f;
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
const int64_t nrows = ggml_nrows(gf->nodes[i]->src0);
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
GGML_ASSERT(false);

View file

@ -43,23 +43,23 @@ kernel void kernel_add(
}
kernel void kernel_relu(
device const float * src,
device const float * src0,
device float * dst,
uint gid[[thread_position_in_grid]]) {
dst[gid] = max(0.0f, src[gid]);
dst[gid] = max(0.0f, src0[gid]);
}
// TODO: broken
kernel void kernel_soft_max(
device const float * src,
device const float * src0,
device float * dst) {
float max = 0.0f;
for (int i = 0; i < nsoftmax; i++) {
max = MAX(max, src[i]);
max = MAX(max, src0[i]);
}
float sum = 0.0f;
for (int i = 0; i < nsoftmax; i++) {
dst[i] = exp(src[i] - max);
dst[i] = exp(src0[i] - max);
sum += dst[i];
}
for (int i = 0; i < nsoftmax; i++) {
@ -75,8 +75,6 @@ kernel void kernel_get_rows_q4_0(
constant uint64_t & nb01,
constant uint64_t & nb1,
uint gid[[thread_position_in_grid]]) {
device const block_q4_0 * src = (device const block_q4_0 *)src0;
const int i = gid;
const int r = ((device int32_t *) src1)[i];
@ -84,3 +82,26 @@ kernel void kernel_get_rows_q4_0(
(device const block_q4_0 *) ((device char *) src0 + r*nb01),
(device float *) ((device char *) dst + i*nb1), ne00);
}
kernel void kernel_rms_norm(
device const void * src0,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant float & eps,
uint gid[[thread_position_in_grid]]) {
device const float * x = (device const float *) ((device const char *) src0 + gid*nb01);
float sum = 0.0f;
for (int i00 = 0; i00 < ne00; i00++) {
sum += x[i00] * x[i00];
}
const float mean = sum/ne00;
const float scale = 1.0f/sqrt(mean + eps);
device float * y = dst + gid*ne00;
for (int i00 = 0; i00 < ne00; i00++) {
y[i00] = x[i00] * scale;
}
}

4
ggml.c
View file

@ -3723,7 +3723,7 @@ int64_t ggml_nelements(const struct ggml_tensor * tensor) {
return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
}
int ggml_nrows(const struct ggml_tensor * tensor) {
int64_t ggml_nrows(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
@ -9245,7 +9245,7 @@ static void ggml_compute_forward_rms_norm_f32(
sum += (ggml_float)(x[i00] * x[i00]);
}
float mean = sum/ne00;
const float mean = sum/ne00;
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);

1
ggml.h
View file

@ -425,6 +425,7 @@ extern "C" {
GGML_API void ggml_print_objects(const struct ggml_context * ctx);
GGML_API int64_t ggml_nelements(const struct ggml_tensor * tensor);
GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor);
GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor);
GGML_API int ggml_blck_size (enum ggml_type type);

View file

@ -1252,7 +1252,6 @@ static bool llama_eval_internal(
memcpy(embd->data, tokens, N*ggml_element_size(embd));
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
ggml_set_name(inpL, "mtl-check");
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
@ -1264,16 +1263,15 @@ static bool llama_eval_internal(
// norm
{
cur = ggml_rms_norm(ctx0, inpL);
// TODO: TMP !!!!
if (il == 0) {
ggml_set_name(cur, "mtl-check");
}
// cur = cur*attention_norm(broadcasted)
cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
}
// TODO: TMP !!!!
//if (il == 0) {
// ggml_set_name(cur, "mtl-check");
//}
// self-attention
{
// compute Q and K and RoPE them