mtl : add rope kernel

commit 1213af76ce (parent 6af6a05663)
Author: Georgi Gerganov
Date:   2023-05-31 22:28:59 +03:00
3 changed files with 145 additions and 9 deletions


@@ -41,6 +41,9 @@ struct ggml_mtl_context {
id<MTLFunction> function_mul_mat_q4_0;
id<MTLComputePipelineState> pipeline_mul_mat_q4_0;
id<MTLFunction> function_rope;
id<MTLComputePipelineState> pipeline_rope;
};
// MSL code
@@ -148,6 +151,10 @@ struct ggml_mtl_context * llama_mtl_init(
ctx->function_mul_mat_q4_0 = [ctx->library newFunctionWithName:@"kernel_mul_mat_q4_0"];
ctx->pipeline_mul_mat_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_q4_0 error:nil];
fprintf(stderr, "%s: loaded kernel_mul_mat_q4_0: %p\n", __func__, (void *) ctx->pipeline_mul_mat_q4_0);
ctx->function_rope = [ctx->library newFunctionWithName:@"kernel_rope"];
ctx->pipeline_rope = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope error:nil];
fprintf(stderr, "%s: loaded kernel_rope: %p\n", __func__, (void *) ctx->pipeline_rope);
}
// MTLBuffer approach
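Note that both pipeline states above are created with error:nil, so a shader build failure only surfaces as a NULL pointer in the fprintf. A minimal sketch of reporting the error instead, under the assumption that llama_mtl_init (which returns a pointer) may return NULL on failure; this is not part of the commit:

NSError * error = nil;

ctx->function_rope = [ctx->library newFunctionWithName:@"kernel_rope"];
ctx->pipeline_rope = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope error:&error];
if (ctx->pipeline_rope == nil) {
    // error describes why the pipeline could not be built
    fprintf(stderr, "%s: failed to create rope pipeline: %s\n",
            __func__, [[error localizedDescription] UTF8String]);
    return NULL;
}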
@@ -250,6 +257,10 @@ int llama_mtl_eval(
fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
switch (gf->nodes[i]->op) {
case GGML_OP_RESHAPE:
{
// noop
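// (a reshape shares its source's data and only changes the ne/nb
// metadata, so there is nothing to encode for the GPU)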
} break;
case GGML_OP_ADD:
{
if (encoder == nil) {
@@ -453,6 +464,68 @@ int llama_mtl_eval(
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_ROPE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
const int64_t ne00 = gf->nodes[i]->src0->ne[0];
const int64_t ne01 = gf->nodes[i]->src0->ne[1];
const int64_t ne02 = gf->nodes[i]->src0->ne[2];
const int64_t ne03 = gf->nodes[i]->src0->ne[3];
const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
const uint64_t nb02 = gf->nodes[i]->src0->nb[2];
const uint64_t nb03 = gf->nodes[i]->src0->nb[3];
const int64_t ne0 = gf->nodes[i]->ne[0];
const int64_t ne1 = gf->nodes[i]->ne[1];
const int64_t ne2 = gf->nodes[i]->ne[2];
const int64_t ne3 = gf->nodes[i]->ne[3];
const uint64_t nb0 = gf->nodes[i]->nb[0];
const uint64_t nb1 = gf->nodes[i]->nb[1];
const uint64_t nb2 = gf->nodes[i]->nb[2];
const uint64_t nb3 = gf->nodes[i]->nb[3];
const int n_past = ((int32_t *) gf->nodes[i]->src1->data)[0]; // TODO: TMP !!!!!
const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
const int mode = ((int32_t *) gf->nodes[i]->src1->data)[2];
printf("rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
printf("rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
printf("rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
[encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
GGML_ASSERT(false);
@@ -486,7 +559,7 @@ int llama_mtl_eval(
{
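// GPUStartTime / GPUEndTime are CFTimeInterval values, i.e. seconds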
const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
fprintf(stderr, "%s: time elapsed = %f\n", __func__, time_elapsed);
fprintf(stderr, "%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
}
// TODO


@@ -210,3 +210,58 @@ kernel void kernel_mul_mat_q4_0(
dst[r1*ne0 + r0] = sum[0];
}
}
kernel void kernel_rope(
device const void * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int & n_past,
constant int & n_dims,
constant int & mode,
uint3 tpig[[thread_position_in_grid]]) {
const int64_t i3 = tpig[2];
const int64_t i2 = tpig[1];
const int64_t i1 = tpig[0];
const bool is_neox = mode & 2;
const float theta_scale = pow(10000.0, -2.0f/n_dims);
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
float theta = (float)p;
if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);
theta *= theta_scale;
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0];
const float x1 = src[1];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
} else {
// TODO: implement
}
}
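For reference, the non-neox branch above is the standard RoPE rotation. With d = n_dims (which equals the row size ne0 at the LLaMA call sites) and token position p = n_past + i2 (mode bit 0 unset), pair j of the row is rotated as

\theta_j = p \cdot 10000^{-2j/d}, \qquad
\begin{pmatrix} x'_{2j} \\ x'_{2j+1} \end{pmatrix} =
\begin{pmatrix} \cos\theta_j & -\sin\theta_j \\ \sin\theta_j & \cos\theta_j \end{pmatrix}
\begin{pmatrix} x_{2j} \\ x_{2j+1} \end{pmatrix}

which is exactly what the loop computes by starting theta at p and multiplying it by theta_scale = 10000^{-2/d} after each pair.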


@@ -1270,19 +1270,20 @@ static bool llama_eval_internal(
// self-attention
{
-auto * x = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-// TODO: TMP !!!!
-if (il == 0) {
-ggml_set_name(x, "mtl-check");
-}
+//auto * x = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+//struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);
// compute Q and K and RoPE them
-//struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
-struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);
+struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
ggml_set_name(Qcur, "Qcur");
ggml_set_name(Kcur, "Kcur");
+// TODO: TMP !!!!
+if (il == 0) {
+ggml_set_name(Qcur, "mtl-check");
+}
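// (the smaller graph exported further down is built from this "mtl-check"
// tensor, so the first layer's RoPE output is what gets verified on Metal)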
// store key and value to memory
{
// compute the transposed [N, n_embd] V matrix
@@ -1437,7 +1438,14 @@ static bool llama_eval_internal(
//ggml_graph_compute (ctx0, &gf);
// lets export a smaller graph to get things rolling -- baby steps first
ggml_build_forward_expand(&gf_export, ggml_get_tensor(ctx0, "mtl-check"));
{
struct ggml_tensor * t = ggml_get_tensor(ctx0, "mtl-check");
if (!t) {
fprintf(stderr, "%s: failed to find tensor 'mtl-check'\n", __func__);
exit(1);
}
ggml_build_forward_expand(&gf_export, t);
}
// print
{