mtl : add rms_norm kernel + confirm working

commit 72256ebd2b
parent 794704e409
Georgi Gerganov 2023-05-30 19:03:04 +03:00
5 changed files with 66 additions and 15 deletions

View file

@@ -32,6 +32,9 @@ struct ggml_mtl_context {
     id<MTLFunction>             function_get_rows_q4_0;
     id<MTLComputePipelineState> pipeline_get_rows_q4_0;
+
+    id<MTLFunction>             function_rms_norm;
+    id<MTLComputePipelineState> pipeline_rms_norm;
 };
 
 // MSL code
@@ -127,6 +130,10 @@ struct ggml_mtl_context * llama_mtl_init(
         ctx->function_get_rows_q4_0 = [ctx->library newFunctionWithName:@"kernel_get_rows_q4_0"];
         ctx->pipeline_get_rows_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_get_rows_q4_0 error:nil];
         fprintf(stderr, "%s: loaded kernel_get_rows_q4_0: %p\n", __func__, (void *) ctx->pipeline_get_rows_q4_0);
+
+        ctx->function_rms_norm = [ctx->library newFunctionWithName:@"kernel_rms_norm"];
+        ctx->pipeline_rms_norm = [ctx->device newComputePipelineStateWithFunction:ctx->function_rms_norm error:nil];
+        fprintf(stderr, "%s: loaded kernel_rms_norm: %p\n", __func__, (void *) ctx->pipeline_rms_norm);
     }
 
     // MTLBuffer approach
@@ -348,6 +355,30 @@ int llama_mtl_eval(
                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                } break;
+            case GGML_OP_RMS_NORM:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    const int64_t  ne00 = gf->nodes[i]->src0->ne[0];
+                    const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
+                    const float    eps  = 1e-6f;
+
+                    [encoder setComputePipelineState:ctx->pipeline_rms_norm];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+                    [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+
+                    const int64_t nrows = ggml_nrows(gf->nodes[i]->src0);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
            default:
                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
                GGML_ASSERT(false);
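
Note: the dispatch above launches ggml_nrows(src0) threadgroups of a single thread each, so one GPU thread normalizes one full row of ne00 floats. A minimal CPU sketch of what that dispatch computes, useful as a reference when confirming the kernel; it assumes f32 input with a row stride of nb01 bytes and a contiguous f32 output, and the helper name rms_norm_ref is illustrative, not part of the code above:

    #include <math.h>
    #include <stdint.h>

    // One "threadgroup" per row: each iteration of the outer loop mirrors
    // what a single GPU thread does for its row.
    static void rms_norm_ref(const void * src0, float * dst,
                             int64_t ne00, uint64_t nb01, int64_t nrows, float eps) {
        for (int64_t r = 0; r < nrows; r++) {
            const float * x = (const float *) ((const char *) src0 + r*nb01);
            float sum = 0.0f;
            for (int64_t i = 0; i < ne00; i++) {
                sum += x[i]*x[i];            // sum of squares over the row
            }
            const float scale = 1.0f/sqrtf(sum/ne00 + eps);
            for (int64_t i = 0; i < ne00; i++) {
                dst[r*ne00 + i] = x[i]*scale; // output assumed contiguous
            }
        }
    }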

View file

@@ -43,23 +43,23 @@ kernel void kernel_add(
 }
 
 kernel void kernel_relu(
-        device const float * src,
+        device const float * src0,
         device       float * dst,
         uint gid[[thread_position_in_grid]]) {
-    dst[gid] = max(0.0f, src[gid]);
+    dst[gid] = max(0.0f, src0[gid]);
 }
 
 // TODO: broken
 kernel void kernel_soft_max(
-        device const float * src,
+        device const float * src0,
         device       float * dst) {
     float max = 0.0f;
     for (int i = 0; i < nsoftmax; i++) {
-        max = MAX(max, src[i]);
+        max = MAX(max, src0[i]);
     }
     float sum = 0.0f;
     for (int i = 0; i < nsoftmax; i++) {
-        dst[i] = exp(src[i] - max);
+        dst[i] = exp(src0[i] - max);
         sum += dst[i];
     }
     for (int i = 0; i < nsoftmax; i++) {
@@ -75,8 +75,6 @@ kernel void kernel_get_rows_q4_0(
         constant uint64_t & nb01,
         constant uint64_t & nb1,
         uint gid[[thread_position_in_grid]]) {
-    device const block_q4_0 * src = (device const block_q4_0 *) src0;
-
     const int i = gid;
     const int r = ((device int32_t *) src1)[i];
@@ -84,3 +82,26 @@ kernel void kernel_get_rows_q4_0(
         (device const block_q4_0 *) ((device char *) src0 + r*nb01),
         (device       float      *) ((device char *)  dst + i*nb1), ne00);
 }
+
+kernel void kernel_rms_norm(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant     float & eps,
+        uint gid[[thread_position_in_grid]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + gid*nb01);
+
+    float sum = 0.0f;
+    for (int i00 = 0; i00 < ne00; i00++) {
+        sum += x[i00] * x[i00];
+    }
+
+    const float mean  = sum/ne00;
+    const float scale = 1.0f/sqrt(mean + eps);
+
+    device float * y = dst + gid*ne00;
+    for (int i00 = 0; i00 < ne00; i00++) {
+        y[i00] = x[i00] * scale;
+    }
+}
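
A few assumptions are baked into this kernel: src0 is declared device const void * but is read as f32 rows strided by nb01 bytes, while the output is indexed as dst + gid*ne00, i.e. dst is assumed contiguous. The sum also accumulates in plain float, unlike the CPU path, which accumulates in ggml_float, so "confirm working" comparisons should allow a small tolerance rather than require exact equality. A minimal comparison helper along those lines; the name and tolerance are illustrative:

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    // Compare GPU output against a CPU reference (e.g. rms_norm_ref above)
    // within a tolerance that absorbs float-vs-double accumulation drift.
    static int check_rms_norm(const float * gpu, const float * ref, int64_t n) {
        for (int64_t i = 0; i < n; i++) {
            if (fabsf(gpu[i] - ref[i]) > 1e-4f) {
                fprintf(stderr, "mismatch at %lld: %f vs %f\n", (long long) i, gpu[i], ref[i]);
                return 0;
            }
        }
        return 1;
    }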

ggml.c
View file

@@ -3723,7 +3723,7 @@ int64_t ggml_nelements(const struct ggml_tensor * tensor) {
     return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
 }
 
-int ggml_nrows(const struct ggml_tensor * tensor) {
+int64_t ggml_nrows(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
@@ -9245,7 +9245,7 @@ static void ggml_compute_forward_rms_norm_f32(
                     sum += (ggml_float)(x[i00] * x[i00]);
                 }
 
-                float mean = sum/ne00;
+                const float mean = sum/ne00;
 
                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
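
The ggml_nrows widening matters because the ne[] entries are int64_t and their product can exceed INT32_MAX for large tensors; returning int would silently truncate. A small illustration with a hypothetical shape:

    #include <stdint.h>
    #include <stdio.h>

    // Why ggml_nrows returns int64_t: the row count is a product of
    // three int64_t dimensions and can overflow a 32-bit int.
    int main(void) {
        const int64_t ne1 = 100000, ne2 = 64, ne3 = 512; // hypothetical shape
        const int64_t nrows = ne1*ne2*ne3;               // 3,276,800,000 > INT32_MAX
        printf("nrows = %lld\n", (long long) nrows);     // would wrap if computed as int
        return 0;
    }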

ggml.h
View file

@@ -425,6 +425,7 @@ extern "C" {
     GGML_API void    ggml_print_objects(const struct ggml_context * ctx);
 
     GGML_API int64_t ggml_nelements(const struct ggml_tensor * tensor);
+    GGML_API int64_t ggml_nrows    (const struct ggml_tensor * tensor);
     GGML_API size_t  ggml_nbytes   (const struct ggml_tensor * tensor);
 
     GGML_API int     ggml_blck_size (enum ggml_type type);

View file

@@ -1252,7 +1252,6 @@ static bool llama_eval_internal(
     memcpy(embd->data, tokens, N*ggml_element_size(embd));
 
     struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
-    ggml_set_name(inpL, "mtl-check");
 
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * inpSA = inpL;
@@ -1264,16 +1263,15 @@ static bool llama_eval_internal(
         // norm
         {
             cur = ggml_rms_norm(ctx0, inpL);
 
+            // TODO: TMP !!!!
+            if (il == 0) {
+                ggml_set_name(cur, "mtl-check");
+            }
+
             // cur = cur*attention_norm(broadcasted)
             cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
         }
 
-        // TODO: TMP !!!!
-        //if (il == 0) {
-        //    ggml_set_name(cur, "mtl-check");
-        //}
-
         // self-attention
         {
             // compute Q and K and RoPE them
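
The "mtl-check" tag moves from the embedding lookup to the layer-0 ggml_rms_norm output, so the Metal path can fetch exactly the tensor the new kernel produces and compare it against the CPU result. A sketch of how a graph consumer might locate the tagged node, assuming only the public ggml_cgraph/ggml_tensor fields this diff already relies on (find_by_name is illustrative):

    #include <string.h>
    #include "ggml.h"

    // Walk the graph and return the node tagged with ggml_set_name.
    static struct ggml_tensor * find_by_name(struct ggml_cgraph * gf, const char * name) {
        for (int i = 0; i < gf->n_nodes; i++) {
            if (strcmp(gf->nodes[i]->name, name) == 0) {
                return gf->nodes[i];
            }
        }
        return NULL;
    }

With the change above, find_by_name(gf, "mtl-check") would return the layer-0 RMS-norm output instead of the embedding lookup.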