mtl : add rms_norm kernel + confirm working

parent 794704e409
commit 72256ebd2b

5 changed files with 66 additions and 15 deletions

@@ -32,6 +32,9 @@ struct ggml_mtl_context {
     id<MTLFunction>             function_get_rows_q4_0;
     id<MTLComputePipelineState> pipeline_get_rows_q4_0;
+
+    id<MTLFunction>             function_rms_norm;
+    id<MTLComputePipelineState> pipeline_rms_norm;
 };
 
 // MSL code
@@ -127,6 +130,10 @@ struct ggml_mtl_context * llama_mtl_init(
         ctx->function_get_rows_q4_0 = [ctx->library newFunctionWithName:@"kernel_get_rows_q4_0"];
         ctx->pipeline_get_rows_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_get_rows_q4_0 error:nil];
         fprintf(stderr, "%s: loaded kernel_get_rows_q4_0: %p\n", __func__, (void *) ctx->pipeline_get_rows_q4_0);
+
+        ctx->function_rms_norm = [ctx->library newFunctionWithName:@"kernel_rms_norm"];
+        ctx->pipeline_rms_norm = [ctx->device newComputePipelineStateWithFunction:ctx->function_rms_norm error:nil];
+        fprintf(stderr, "%s: loaded kernel_rms_norm: %p\n", __func__, (void *) ctx->pipeline_rms_norm);
     }
 
     // MTLBuffer approach
@@ -348,6 +355,30 @@ int llama_mtl_eval(
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case GGML_OP_RMS_NORM:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    const int64_t  ne00 = gf->nodes[i]->src0->ne[0];
+                    const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
+                    const float     eps = 1e-6f;
+
+                    [encoder setComputePipelineState:ctx->pipeline_rms_norm];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+                    [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+
+                    const int64_t nrows = ggml_nrows(gf->nodes[i]->src0);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
             default:
                 fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
                 GGML_ASSERT(false);
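For reference, each threadgroup launched above handles one row with a single thread, so the per-row work is plain RMS normalization. A minimal CPU-side mirror of that math in C (a sketch for cross-checking, not part of this commit; the helper name is made up):

    #include <math.h>
    #include <stdint.h>

    // RMS-normalize one row of ne00 floats, mirroring the Metal kernel:
    // y[i] = x[i] / sqrt(mean(x^2) + eps)
    static void rms_norm_row_ref(const float * x, float * y, int64_t ne00, float eps) {
        float sum = 0.0f;
        for (int64_t i = 0; i < ne00; i++) {
            sum += x[i]*x[i];
        }
        const float mean  = sum/ne00;
        const float scale = 1.0f/sqrtf(mean + eps);
        for (int64_t i = 0; i < ne00; i++) {
            y[i] = x[i]*scale;
        }
    }
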
@@ -43,23 +43,23 @@ kernel void kernel_add(
 }
 
 kernel void kernel_relu(
-        device const float * src,
+        device const float * src0,
         device       float * dst,
         uint gid[[thread_position_in_grid]]) {
-    dst[gid] = max(0.0f, src[gid]);
+    dst[gid] = max(0.0f, src0[gid]);
 }
 
 // TODO: broken
 kernel void kernel_soft_max(
-        device const float * src,
+        device const float * src0,
         device       float * dst) {
     float max = 0.0f;
     for (int i = 0; i < nsoftmax; i++) {
-        max = MAX(max, src[i]);
+        max = MAX(max, src0[i]);
     }
     float sum = 0.0f;
     for (int i = 0; i < nsoftmax; i++) {
-        dst[i] = exp(src[i] - max);
+        dst[i] = exp(src0[i] - max);
         sum += dst[i];
     }
     for (int i = 0; i < nsoftmax; i++) {
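The soft_max kernel is still flagged broken; one visible issue is that it initializes max to 0.0f, which is wrong for all-negative inputs, and it relies on a hard-coded nsoftmax count. For comparison, a numerically stable reference in C (a sketch; the helper name is made up):

    #include <math.h>
    #include <stdint.h>

    // Stable softmax over n floats: subtract the true max before exp,
    // then normalize by the sum.
    static void soft_max_ref(const float * x, float * y, int64_t n) {
        float max = x[0];
        for (int64_t i = 1; i < n; i++) {
            if (x[i] > max) max = x[i];
        }
        float sum = 0.0f;
        for (int64_t i = 0; i < n; i++) {
            y[i] = expf(x[i] - max);
            sum  += y[i];
        }
        for (int64_t i = 0; i < n; i++) {
            y[i] /= sum;
        }
    }
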
@@ -75,8 +75,6 @@ kernel void kernel_get_rows_q4_0(
         constant uint64_t & nb01,
         constant uint64_t & nb1,
         uint gid[[thread_position_in_grid]]) {
-    device const block_q4_0 * src = (device const block_q4_0 *)src0;
-
     const int i = gid;
     const int r = ((device int32_t *) src1)[i];
 
@@ -84,3 +82,26 @@ kernel void kernel_get_rows_q4_0(
             (device const block_q4_0 *) ((device char *) src0 + r*nb01),
             (device       float *) ((device char *) dst  + i*nb1), ne00);
 }
+
+kernel void kernel_rms_norm(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant     float & eps,
+        uint gid[[thread_position_in_grid]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + gid*nb01);
+
+    float sum = 0.0f;
+    for (int i00 = 0; i00 < ne00; i00++) {
+        sum += x[i00] * x[i00];
+    }
+
+    const float mean  = sum/ne00;
+    const float scale = 1.0f/sqrt(mean + eps);
+
+    device float * y = dst + gid*ne00;
+    for (int i00 = 0; i00 < ne00; i00++) {
+        y[i00] = x[i00] * scale;
+    }
+}
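Two details worth noting: this is RMS norm, so unlike LayerNorm there is no mean subtraction, and eps keeps the scale finite for an all-zero row. Also, the input row is read with the stride gid*nb01 while the output is written densely at gid*ne00. The "confirm working" in the commit title suggests a CPU-vs-GPU comparison; a hedged sketch of such a check in C (buffer names are hypothetical, not from this commit):

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    // Compare GPU output against a CPU reference within a small tolerance.
    // `gpu` and `ref` are hypothetical pointers to the two result buffers.
    static int check_close(const float * gpu, const float * ref, int64_t n, float tol) {
        for (int64_t i = 0; i < n; i++) {
            if (fabsf(gpu[i] - ref[i]) > tol) {
                fprintf(stderr, "mismatch at %lld: %f vs %f\n", (long long) i, gpu[i], ref[i]);
                return 0;
            }
        }
        return 1;
    }
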

ggml.c

@@ -3723,7 +3723,7 @@ int64_t ggml_nelements(const struct ggml_tensor * tensor) {
     return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
 }
 
-int ggml_nrows(const struct ggml_tensor * tensor) {
+int64_t ggml_nrows(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
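Widening the return type to int64_t matches ggml_nelements and, presumably, guards against overflow once ne[1]*ne[2]*ne[3] no longer fits in 32 bits. A small illustration (values chosen only to show the truncation):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t ne1 = 100000, ne2 = 100000, ne3 = 1;
        const int64_t nrows = ne1*ne2*ne3; // 10000000000
        printf("as int64_t: %lld\n", (long long) nrows);
        printf("as int:     %d\n",   (int) nrows); // truncated: the product does not fit in 32 bits
        return 0;
    }
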
@@ -9245,7 +9245,7 @@ static void ggml_compute_forward_rms_norm_f32(
         sum += (ggml_float)(x[i00] * x[i00]);
     }
 
-    float mean = sum/ne00;
+    const float mean = sum/ne00;
 
     float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 

ggml.h

@@ -425,6 +425,7 @@ extern "C" {
     GGML_API void    ggml_print_objects(const struct ggml_context * ctx);
 
     GGML_API int64_t ggml_nelements(const struct ggml_tensor * tensor);
+    GGML_API int64_t ggml_nrows   (const struct ggml_tensor * tensor);
     GGML_API size_t  ggml_nbytes  (const struct ggml_tensor * tensor);
 
     GGML_API int     ggml_blck_size (enum ggml_type type);

llama.cpp

@@ -1252,7 +1252,6 @@ static bool llama_eval_internal(
     memcpy(embd->data, tokens, N*ggml_element_size(embd));
 
     struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
-    ggml_set_name(inpL, "mtl-check");
 
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * inpSA = inpL;
@@ -1264,16 +1263,15 @@ static bool llama_eval_internal(
         // norm
         {
             cur = ggml_rms_norm(ctx0, inpL);
+            // TODO: TMP !!!!
+            if (il == 0) {
+                ggml_set_name(cur, "mtl-check");
+            }
 
             // cur = cur*attention_norm(broadcasted)
             cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
         }
 
-        // TODO: TMP !!!!
-        //if (il == 0) {
-        //    ggml_set_name(cur, "mtl-check");
-        //}
-
         // self-attention
         {
             // compute Q and K and RoPE them
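Naming the first layer's rms_norm output "mtl-check" lets the Metal path locate the same node in the graph and compare its output against the CPU result. A sketch of such a by-name lookup over the graph nodes used above (a hand-rolled helper written for illustration, not part of this commit):

    #include <string.h>
    #include "ggml.h"

    // Find a graph node previously tagged with ggml_set_name().
    static struct ggml_tensor * find_named_node(struct ggml_cgraph * gf, const char * name) {
        for (int i = 0; i < gf->n_nodes; i++) {
            if (strcmp(gf->nodes[i]->name, name) == 0) {
                return gf->nodes[i];
            }
        }
        return NULL;
    }

Usage would then look like: struct ggml_tensor * t = find_named_node(gf, "mtl-check");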