mtl : add cpy kernel + handle view ops

Georgi Gerganov 2023-06-01 19:21:28 +03:00
parent 94ea9e7bfe
commit 948fcfde7e
GPG key ID: 449E073F9DC10735
3 changed files with 191 additions and 21 deletions


@@ -44,6 +44,9 @@ struct ggml_mtl_context {
id<MTLFunction> function_rope;
id<MTLComputePipelineState> pipeline_rope;
id<MTLFunction> function_cpy_f32_f16;
id<MTLComputePipelineState> pipeline_cpy_f32_f16;
};
// MSL code
@@ -155,6 +158,10 @@ struct ggml_mtl_context * llama_mtl_init(
ctx->function_rope = [ctx->library newFunctionWithName:@"kernel_rope"];
ctx->pipeline_rope = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope error:nil];
fprintf(stderr, "%s: loaded kernel_rope: %p\n", __func__, (void *) ctx->pipeline_rope);
ctx->function_cpy_f32_f16 = [ctx->library newFunctionWithName:@"kernel_cpy_f32_f16"];
ctx->pipeline_cpy_f32_f16 = [ctx->device newComputePipelineStateWithFunction:ctx->function_cpy_f32_f16 error:nil];
fprintf(stderr, "%s: loaded kernel_cpy_f32_f16: %p\n", __func__, (void *) ctx->pipeline_cpy_f32_f16);
}
// MTLBuffer approach
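For context on the hunk above: newFunctionWithName: and newComputePipelineStateWithFunction:error: both return nil on failure, which is why the fprintf logs the resulting pipeline pointers. Below is a minimal standalone sketch of the same pattern with the error parameter filled in instead of nil, so a broken kernel reports why it failed to compile. The dummy kernel source and names here are made up for illustration; this is not code from the commit.

    #import <Foundation/Foundation.h>
    #import <Metal/Metal.h>
    #include <stdio.h>

    int main(void) {
        @autoreleasepool {
            id<MTLDevice> device = MTLCreateSystemDefaultDevice();
            if (device == nil) {
                fprintf(stderr, "no Metal device\n");
                return 1;
            }

            NSError * error = nil;
            // compile a tiny placeholder kernel from source
            id<MTLLibrary> library = [device newLibraryWithSource:
                @"kernel void kernel_dummy(device float * dst [[buffer(0)]],\n"
                @"                         uint tpig [[thread_position_in_grid]]) {\n"
                @"    dst[tpig] = 0.0f;\n"
                @"}\n"
                options:nil error:&error];
            if (library == nil) {
                fprintf(stderr, "library error: %s\n", [[error description] UTF8String]);
                return 1;
            }

            id<MTLFunction> fn = [library newFunctionWithName:@"kernel_dummy"];
            if (fn == nil) {
                fprintf(stderr, "kernel_dummy not found\n");
                return 1;
            }

            id<MTLComputePipelineState> pso =
                [device newComputePipelineStateWithFunction:fn error:&error];
            if (pso == nil) {
                fprintf(stderr, "pipeline error: %s\n", [[error description] UTF8String]);
                return 1;
            }

            fprintf(stderr, "loaded kernel_dummy: %p\n", (void *) pso);
        }
        return 0;
    }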
@@ -258,6 +265,7 @@ int llama_mtl_eval(
switch (gf->nodes[i]->op) {
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
{
// noop
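GGML_OP_VIEW joins RESHAPE and TRANSPOSE as a no-op here because these ops only change a tensor's metadata (shape ne, byte strides nb, data offset); they never touch the underlying buffer, so there is nothing for the GPU to do. A small C illustration of that idea, using a simplified two-dimensional stand-in for ggml's tensor (not the real ggml struct):

    #include <stdint.h>
    #include <stdio.h>

    // Simplified 2D "tensor" view: data pointer + element counts + byte strides.
    struct view2d {
        float  * data;
        int64_t  ne[2]; // number of elements per dimension
        uint64_t nb[2]; // stride in bytes per dimension
    };

    // A "transpose" just swaps shape and strides; no data is copied.
    static struct view2d transpose2d(struct view2d t) {
        struct view2d r = { t.data, { t.ne[1], t.ne[0] }, { t.nb[1], t.nb[0] } };
        return r;
    }

    static float get2d(struct view2d t, int64_t i1, int64_t i0) {
        return *(float *) ((char *) t.data + i1*t.nb[1] + i0*t.nb[0]);
    }

    int main(void) {
        float data[2][3] = { { 1, 2, 3 }, { 4, 5, 6 } };
        struct view2d a = { &data[0][0], { 3, 2 }, { sizeof(float), 3*sizeof(float) } };
        struct view2d b = transpose2d(a);

        // a[1][2] == b[2][1] == 6.0, read from the same memory
        printf("%.1f %.1f\n", get2d(a, 1, 2), get2d(b, 2, 1));
        return 0;
    }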
@@ -527,6 +535,76 @@ int llama_mtl_eval(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_CPY:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
const int64_t ne00 = gf->nodes[i]->src0->ne[0];
const int64_t ne01 = gf->nodes[i]->src0->ne[1];
const int64_t ne02 = gf->nodes[i]->src0->ne[2];
const int64_t ne03 = gf->nodes[i]->src0->ne[3];
const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
const uint64_t nb02 = gf->nodes[i]->src0->nb[2];
const uint64_t nb03 = gf->nodes[i]->src0->nb[3];
const int64_t ne0 = gf->nodes[i]->ne[0];
const int64_t ne1 = gf->nodes[i]->ne[1];
const int64_t ne2 = gf->nodes[i]->ne[2];
const int64_t ne3 = gf->nodes[i]->ne[3];
const uint64_t nb0 = gf->nodes[i]->nb[0];
const uint64_t nb1 = gf->nodes[i]->nb[1];
const uint64_t nb2 = gf->nodes[i]->nb[2];
const uint64_t nb3 = gf->nodes[i]->nb[3];
const enum ggml_type src0t = gf->nodes[i]->src0->type;
const enum ggml_type dstt = gf->nodes[i]->type;
printf("cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
printf("cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
printf("cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
printf("cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3);
printf("cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));
switch (src0t) {
case GGML_TYPE_F32:
{
switch (dstt) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;
default: GGML_ASSERT(false && "not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;
default:
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
GGML_ASSERT(false);
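The GGML_OP_CPY dispatch above launches one 32-thread threadgroup per source row: the grid is ne01 x ne02 x ne03 groups, and each thread inside kernel_cpy_f32_f16 strides over the row's ne00 elements. A plain C sketch of that work split, with a made-up example shape, just to show that every element is visited exactly once:

    #include <stdint.h>
    #include <stdio.h>

    // How the GGML_OP_CPY dispatch carves up the work: one threadgroup per
    // (i01, i02, i03) source row, 32 threads per group, each thread handling
    // elements i00 = tpitg, tpitg+32, tpitg+64, ...
    enum { THREADS_PER_GROUP = 32 };   // matches MTLSizeMake(32, 1, 1) above

    int main(void) {
        const int64_t ne00 = 100, ne01 = 4, ne02 = 2, ne03 = 1; // example shape

        int64_t elements_visited = 0;
        for (int64_t i03 = 0; i03 < ne03; ++i03)          // tgpig.z
        for (int64_t i02 = 0; i02 < ne02; ++i02)          // tgpig.y
        for (int64_t i01 = 0; i01 < ne01; ++i01)          // tgpig.x
        for (int64_t tpitg = 0; tpitg < THREADS_PER_GROUP; ++tpitg)
        for (int64_t i00 = tpitg; i00 < ne00; i00 += THREADS_PER_GROUP) {
            // in the kernel, this is where one element is converted and stored
            elements_visited++;
        }

        // every element of the 100 x 4 x 2 x 1 tensor is visited exactly once
        printf("visited %lld of %lld\n", (long long) elements_visited,
               (long long) (ne00*ne01*ne02*ne03));
        return 0;
    }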
@@ -568,21 +646,41 @@ int llama_mtl_eval(
{
struct ggml_tensor * t = ggml_get_tensor(ctx->ctx_eval, "mtl-check");
if (t->type == GGML_TYPE_F32) {
const float * data = (float *) ctx->out.contents;
printf("data: ");
int n = ggml_nelements(t);
if (n > 10) {
n = 10;
}
for (int i = 0; i < n; i++) {
printf("%f ", data[i]);
}
printf("\n");
double sum = 0.0;
for (int i = 0; i < ggml_nelements(t); i++) {
sum += data[i];
}
printf("sum: %f\n", sum);
} else if (t->type == GGML_TYPE_F16) {
const ggml_fp16_t * data = (const ggml_fp16_t *) ctx->out.contents;
printf("data: ");
int n = ggml_nelements(t);
if (n > 10) {
n = 10;
}
for (int i = 0; i < n; i++) {
printf("%f ", ggml_fp16_to_fp32(data[i]));
}
printf("\n");
double sum = 0.0;
for (int i = 0; i < ggml_nelements(t); i++) {
sum += ggml_fp16_to_fp32(data[i]);
}
printf("sum: %f\n", sum);
} else {
GGML_ASSERT(false && "not implemented");
}
}
return 0;
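Since the new F16 branch prints through ggml_fp16_to_fp32, it may help to recall what that conversion does at the bit level. Below is a self-contained C sketch of an IEEE 754 binary16 to binary32 conversion; ggml's actual implementation may differ (intrinsics or lookup tables depending on the platform), so this is only an illustration of the format:

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    // Bit-level IEEE 754 binary16 -> binary32 conversion, the job that
    // ggml_fp16_to_fp32() performs above.
    static float fp16_to_fp32_ref(uint16_t h) {
        const uint32_t sign = (uint32_t)(h & 0x8000) << 16;
        const uint32_t exp  = (h >> 10) & 0x1f;
        const uint32_t mant = h & 0x3ff;

        uint32_t bits;
        if (exp == 0) {
            // zero or subnormal: value = mant * 2^-24
            float f = (float) mant / 16777216.0f;   // 2^24
            memcpy(&bits, &f, sizeof(bits));
            bits |= sign;
        } else if (exp == 0x1f) {
            // inf or nan
            bits = sign | 0x7f800000u | (mant << 13);
        } else {
            // normal: rebias exponent from 15 to 127
            bits = sign | ((exp + 112) << 23) | (mant << 13);
        }

        float out;
        memcpy(&out, &bits, sizeof(out));
        return out;
    }

    int main(void) {
        // 0x3C00 is 1.0, 0xC000 is -2.0, 0x3555 is ~0.3333 in binary16
        printf("%f %f %f\n", fp16_to_fp32_ref(0x3C00),
                             fp16_to_fp32_ref(0xC000),
                             fp16_to_fp32_ref(0x3555));
        return 0;
    }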


@@ -265,3 +265,45 @@ kernel void kernel_rope(
// TODO: implement
}
}
kernel void kernel_cpy_f32_f16(
device const float * src0,
device half * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0];
}
}
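The interesting part of kernel_cpy_f32_f16 is that the destination may have a different (but element-count-compatible) shape from the source: each source row start is flattened into a linear index n and then re-expressed in destination coordinates. Below is a small C check of exactly that arithmetic, assuming a contiguous destination (as with the 1-D KV-cache view this commit copies into); the shapes are arbitrary examples, not values from the commit:

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    // Check the kernel's row -> destination index math on the CPU: a source of
    // shape 6 x 4 x 1 x 1 copied into a destination viewed as 8 x 3 x 1 x 1
    // (same number of elements).
    int main(void) {
        const int64_t ne00 = 6, ne01 = 4, ne02 = 1, ne03 = 1;   // src shape
        const int64_t ne0  = 8, ne1  = 3, ne2  = 1, ne3  = 1;   // dst shape

        int hit[8*3*1*1];
        memset(hit, 0, sizeof(hit));

        for (int64_t i03 = 0; i03 < ne03; ++i03)
        for (int64_t i02 = 0; i02 < ne02; ++i02)
        for (int64_t i01 = 0; i01 < ne01; ++i01) {
            // same decomposition as in kernel_cpy_f32_f16
            const int64_t n  = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
            const int64_t i3 =  n / (ne2*ne1*ne0);
            const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
            const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
            const int64_t i0 =  n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0;

            for (int64_t i00 = 0; i00 < ne00; ++i00) {
                // the kernel writes dst_data[i00], i.e. i00 elements past (i0,i1,i2,i3)
                const int64_t dst_idx = (i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0) + i00;
                hit[dst_idx]++;
            }
        }

        // every destination element should be written exactly once
        int ok = 1;
        for (int64_t i = 0; i < ne0*ne1*ne2*ne3; ++i) ok &= (hit[i] == 1);
        printf("index math %s\n", ok ? "covers dst exactly once" : "has gaps/overlaps");
        return 0;
    }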


@@ -1283,18 +1283,21 @@ static bool llama_eval_internal(
// compute the transposed [N, n_embd] V matrix
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
// TODO: TMP !!!!
if (il == 0) {
ggml_set_name(Vcur, "mtl-check");
}
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
struct ggml_tensor * t = ggml_cpy(ctx0, Kcur, k);
// TODO: TMP !!!!
if (il == 0) {
ggml_set_name(t, "mtl-check");
}
// important: storing RoPE-ed version of K in the KV cache!
//ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, t);
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
}
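The k view above selects the slice of the (typically F16) KV cache where this layer's new keys land; its byte offset is ggml_element_size(kv_self.k)*n_embd*(il*n_ctx + n_past), and the v view uses a transposed layout with an n_ctx-element row stride. A short worked example in C with hypothetical sizes (not values taken from this commit) to make the offset arithmetic concrete:

    #include <stdint.h>
    #include <stdio.h>

    // Hypothetical setup: n_embd = 4096, n_ctx = 512, F16 cache entries (2 bytes).
    int main(void) {
        const uint64_t elt    = 2;     // ggml_element_size(kv_self.k) for F16
        const uint64_t n_embd = 4096;
        const uint64_t n_ctx  = 512;

        const int N      = 8;          // tokens in this batch
        const int n_past = 3;          // tokens already in the cache
        const int il     = 1;          // layer index

        // k view: N*n_embd elements starting right after layer il's first n_past tokens
        const uint64_t k_offs = (elt*n_embd)*(il*n_ctx + n_past);
        const uint64_t k_len  = elt*n_embd*N;

        // v view (transposed layout): row stride is n_ctx elements, offset skips
        // il full layers plus n_past elements in the first row
        const uint64_t v_stride = n_ctx*elt;
        const uint64_t v_offs   = (il*n_ctx)*elt*n_embd + n_past*elt;

        printf("k view: offset %llu bytes, length %llu bytes\n",
               (unsigned long long) k_offs, (unsigned long long) k_len);
        printf("v view: offset %llu bytes, row stride %llu bytes\n",
               (unsigned long long) v_offs, (unsigned long long) v_stride);
        return 0;
    }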
@@ -1448,7 +1451,7 @@ static bool llama_eval_internal(
// print
{
auto print_t_f32 = [&](struct ggml_tensor * t) {
float * data = (float *)t->data;
printf("data: ");
for (int i = 0; i < std::min((int) t->ne[0], 10); i++) {
@@ -1461,9 +1464,36 @@ static bool llama_eval_internal(
}
printf("sum: %f\n", sum);
};
auto print_t_f16 = [&](struct ggml_tensor * t) {
ggml_fp16_t * data = (ggml_fp16_t *)t->data;
printf("data: ");
for (int i = 0; i < std::min((int) t->ne[0], 10); i++) {
printf("%f ", ggml_fp16_to_fp32(data[i]));
}
printf("\n");
double sum = 0.0;
for (int i = 0; i < ggml_nelements(t); i++) {
sum += ggml_fp16_to_fp32(data[i]);
}
printf("sum: %f\n", sum);
};
ggml_graph_compute(ctx0, &gf_export);
{
auto * t = ggml_get_tensor(ctx0, "mtl-check");
switch (t->type) {
case GGML_TYPE_F32:
print_t_f32(t);
break;
case GGML_TYPE_F16:
print_t_f16(t);
break;
default:
fprintf(stderr, "%s: unsupported type\n", __func__);
exit(1);
}
}
}
if (cgraph_fname) {