mtl : add rope kernel
commit 1213af76ce (parent 6af6a05663)
3 changed files with 145 additions and 9 deletions
@@ -41,6 +41,9 @@ struct ggml_mtl_context {

    id<MTLFunction>             function_mul_mat_q4_0;
    id<MTLComputePipelineState> pipeline_mul_mat_q4_0;

    id<MTLFunction>             function_rope;
    id<MTLComputePipelineState> pipeline_rope;
};

// MSL code
@@ -148,6 +151,10 @@ struct ggml_mtl_context * llama_mtl_init(
        ctx->function_mul_mat_q4_0 = [ctx->library newFunctionWithName:@"kernel_mul_mat_q4_0"];
        ctx->pipeline_mul_mat_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_q4_0 error:nil];
        fprintf(stderr, "%s: loaded kernel_mul_mat_q4_0: %p\n", __func__, (void *) ctx->pipeline_mul_mat_q4_0);

        ctx->function_rope = [ctx->library newFunctionWithName:@"kernel_rope"];
        ctx->pipeline_rope = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope error:nil];
        fprintf(stderr, "%s: loaded kernel_rope: %p\n", __func__, (void *) ctx->pipeline_rope);
    }

    // MTLBuffer approach
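Both pipeline-state objects above are created with error:nil, so a shader compilation failure would only surface later as a nil pipeline. A minimal, illustrative sketch of the same call with the error checked (not part of this commit; the early-return value is an assumption about llama_mtl_init's failure path):

        NSError * error = nil;

        ctx->pipeline_rope = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope error:&error];
        if (ctx->pipeline_rope == nil) {
            // report the Metal error instead of silently continuing with a nil pipeline
            fprintf(stderr, "%s: failed to create pipeline for kernel_rope: %s\n", __func__,
                    [[error localizedDescription] UTF8String]);
            return NULL; // assumed failure path for llama_mtl_init
        }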
@@ -250,6 +257,10 @@ int llama_mtl_eval(
        fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

        switch (gf->nodes[i]->op) {
            case GGML_OP_RESHAPE:
                {
                    // noop
                } break;
            case GGML_OP_ADD:
                {
                    if (encoder == nil) {
@@ -453,6 +464,68 @@ int llama_mtl_eval(

                    [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                } break;
            case GGML_OP_ROPE:
                {
                    if (encoder == nil) {
                        encoder = [command_buffer computeCommandEncoder];
                    }

                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);

                    const int64_t ne00 = gf->nodes[i]->src0->ne[0];
                    const int64_t ne01 = gf->nodes[i]->src0->ne[1];
                    const int64_t ne02 = gf->nodes[i]->src0->ne[2];
                    const int64_t ne03 = gf->nodes[i]->src0->ne[3];

                    const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
                    const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
                    const uint64_t nb02 = gf->nodes[i]->src0->nb[2];
                    const uint64_t nb03 = gf->nodes[i]->src0->nb[3];

                    const int64_t ne0 = gf->nodes[i]->ne[0];
                    const int64_t ne1 = gf->nodes[i]->ne[1];
                    const int64_t ne2 = gf->nodes[i]->ne[2];
                    const int64_t ne3 = gf->nodes[i]->ne[3];

                    const uint64_t nb0 = gf->nodes[i]->nb[0];
                    const uint64_t nb1 = gf->nodes[i]->nb[1];
                    const uint64_t nb2 = gf->nodes[i]->nb[2];
                    const uint64_t nb3 = gf->nodes[i]->nb[3];

                    const int n_past = ((int32_t *) gf->nodes[i]->src1->data)[0]; // TODO: TMP !!!!!
                    const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
                    const int mode   = ((int32_t *) gf->nodes[i]->src1->data)[2];

                    printf("rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
                    printf("rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
                    printf("rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);

                    [encoder setComputePipelineState:ctx->pipeline_rope];
                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
                    [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
                    [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
                    [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
                    [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
                    [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
                    [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
                    [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
                    [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
                    [encoder setBytes:&n_past length:sizeof(   int) atIndex:18];
                    [encoder setBytes:&n_dims length:sizeof(   int) atIndex:19];
                    [encoder setBytes:&mode   length:sizeof(   int) atIndex:20];

                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                } break;
            default:
                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
                GGML_ASSERT(false);
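The ROPE case dispatches a (ne01, ne02, ne03) grid with a single thread per threadgroup, i.e. one kernel invocation per source row. A rough host-side sketch of the iteration space that grid covers (illustrative only, not code from the commit):

    // one kernel_rope invocation per (i1, i2, i3) row of src0
    for (int64_t i3 = 0; i3 < ne03; ++i3) {
        for (int64_t i2 = 0; i2 < ne02; ++i2) {
            for (int64_t i1 = 0; i1 < ne01; ++i1) {
                // body of kernel_rope runs here with tpig = (i1, i2, i3)
            }
        }
    }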
@@ -486,7 +559,7 @@ int llama_mtl_eval(

    {
        const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
        fprintf(stderr, "%s: time elapsed = %f\n", __func__, time_elapsed);
        fprintf(stderr, "%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
    }

    // TODO
@@ -210,3 +210,58 @@ kernel void kernel_mul_mat_q4_0(
        dst[r1*ne0 + r0] = sum[0];
    }
}

kernel void kernel_rope(
        device const void * src0,
        device      float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        constant       int & n_past,
        constant       int & n_dims,
        constant       int & mode,
        uint3 tpig[[thread_position_in_grid]]) {
    const int64_t i3 = tpig[2];
    const int64_t i2 = tpig[1];
    const int64_t i1 = tpig[0];

    const bool is_neox = mode & 2;
    const float theta_scale = pow(10000.0, -2.0f/n_dims);

    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

    float theta = (float)p;

    if (!is_neox) {
        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
            const float cos_theta = cos(theta);
            const float sin_theta = sin(theta);

            theta *= theta_scale;

            device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
            device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);

            const float x0 = src[0];
            const float x1 = src[1];

            dst_data[0] = x0*cos_theta - x1*sin_theta;
            dst_data[1] = x0*sin_theta + x1*cos_theta;
        }
    } else {
        // TODO: implement
    }
}
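For the non-NeoX path, the kernel rotates each pair of consecutive elements in a row by a position-dependent angle. With p = n_past + i2 (or p = i2 when the low mode bit is set), the transform applied to pair k is:

\[
\theta_k = p \cdot 10000^{-2k/n_{\mathrm{dims}}}, \qquad
x'_{2k}   = x_{2k}\cos\theta_k - x_{2k+1}\sin\theta_k, \qquad
x'_{2k+1} = x_{2k}\sin\theta_k + x_{2k+1}\cos\theta_k
\]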
llama.cpp
@@ -1270,19 +1270,20 @@ static bool llama_eval_internal(

        // self-attention
        {
            auto * x = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
            // TODO: TMP !!!!
            if (il == 0) {
                ggml_set_name(x, "mtl-check");
            }
            //auto * x = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
            //struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);

            // compute Q and K and RoPE them
            //struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);
            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
            ggml_set_name(Qcur, "Qcur");
            ggml_set_name(Kcur, "Kcur");

            // TODO: TMP !!!!
            if (il == 0) {
                ggml_set_name(Qcur, "mtl-check");
            }

            // store key and value to memory
            {
                // compute the transposed [N, n_embd] V matrix
@@ -1437,7 +1438,14 @@ static bool llama_eval_internal(
    //ggml_graph_compute (ctx0, &gf);

    // lets export a smaller graph to get things rolling -- baby steps first
    ggml_build_forward_expand(&gf_export, ggml_get_tensor(ctx0, "mtl-check"));
    {
        struct ggml_tensor * t = ggml_get_tensor(ctx0, "mtl-check");
        if (!t) {
            fprintf(stderr, "%s: failed to find tensor 'mtl-check'\n", __func__);
            exit(1);
        }
        ggml_build_forward_expand(&gf_export, t);
    }

    // print
    {