mtl : add cpy kernel + handle view ops
This commit is contained in:
parent
94ea9e7bfe
commit
948fcfde7e
3 changed files with 191 additions and 21 deletions
|
@ -44,6 +44,9 @@ struct ggml_mtl_context {
|
||||||
|
|
||||||
id<MTLFunction> function_rope;
|
id<MTLFunction> function_rope;
|
||||||
id<MTLComputePipelineState> pipeline_rope;
|
id<MTLComputePipelineState> pipeline_rope;
|
||||||
|
|
||||||
|
id<MTLFunction> function_cpy_f32_f16;
|
||||||
|
id<MTLComputePipelineState> pipeline_cpy_f32_f16;
|
||||||
};
|
};
|
||||||
|
|
||||||
// MSL code
|
// MSL code
|
||||||
|
@ -155,6 +158,10 @@ struct ggml_mtl_context * llama_mtl_init(
|
||||||
ctx->function_rope = [ctx->library newFunctionWithName:@"kernel_rope"];
|
ctx->function_rope = [ctx->library newFunctionWithName:@"kernel_rope"];
|
||||||
ctx->pipeline_rope = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope error:nil];
|
ctx->pipeline_rope = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope error:nil];
|
||||||
fprintf(stderr, "%s: loaded kernel_rope: %p\n", __func__, (void *) ctx->pipeline_rope);
|
fprintf(stderr, "%s: loaded kernel_rope: %p\n", __func__, (void *) ctx->pipeline_rope);
|
||||||
|
|
||||||
|
ctx->function_cpy_f32_f16 = [ctx->library newFunctionWithName:@"kernel_cpy_f32_f16"];
|
||||||
|
ctx->pipeline_cpy_f32_f16 = [ctx->device newComputePipelineStateWithFunction:ctx->function_cpy_f32_f16 error:nil];
|
||||||
|
fprintf(stderr, "%s: loaded kernel_cpy_f32_f16: %p\n", __func__, (void *) ctx->pipeline_cpy_f32_f16);
|
||||||
}
|
}
|
||||||
|
|
||||||
// MTLBuffer approach
|
// MTLBuffer approach
|
||||||
|
@ -258,6 +265,7 @@ int llama_mtl_eval(
|
||||||
|
|
||||||
switch (gf->nodes[i]->op) {
|
switch (gf->nodes[i]->op) {
|
||||||
case GGML_OP_RESHAPE:
|
case GGML_OP_RESHAPE:
|
||||||
|
case GGML_OP_VIEW:
|
||||||
case GGML_OP_TRANSPOSE:
|
case GGML_OP_TRANSPOSE:
|
||||||
{
|
{
|
||||||
// noop
|
// noop
|
||||||
|
@ -527,6 +535,76 @@ int llama_mtl_eval(
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_CPY:
    {
        // Encode a copy of src0 into this node on the compute encoder.
        // Only the F32 -> F16 combination has a pipeline; any other type pair asserts.
        if (encoder == nil) {
            // no compute encoder is open yet for this command buffer - create one
            encoder = [command_buffer computeCommandEncoder];
        }

        struct ggml_tensor * cpy_src = gf->nodes[i]->src0;
        struct ggml_tensor * cpy_dst = gf->nodes[i];

        id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, cpy_src, &offs_src0);
        id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, cpy_dst, &offs_dst);

        // element counts (ne) and byte strides (nb) of the source and of the destination
        const int64_t  src_ne[4] = { cpy_src->ne[0], cpy_src->ne[1], cpy_src->ne[2], cpy_src->ne[3] };
        const uint64_t src_nb[4] = { cpy_src->nb[0], cpy_src->nb[1], cpy_src->nb[2], cpy_src->nb[3] };
        const int64_t  dst_ne[4] = { cpy_dst->ne[0], cpy_dst->ne[1], cpy_dst->ne[2], cpy_dst->ne[3] };
        const uint64_t dst_nb[4] = { cpy_dst->nb[0], cpy_dst->nb[1], cpy_dst->nb[2], cpy_dst->nb[3] };

        const enum ggml_type src0t = cpy_src->type;
        const enum ggml_type dstt  = cpy_dst->type;

        // debug trace of the shapes, strides and types involved in the copy
        printf("cpy: %lld x %lld x %lld x %lld\n", src_ne[0], src_ne[1], src_ne[2], src_ne[3]);
        printf("cpy: %lld x %lld x %lld x %lld\n", src_nb[0], src_nb[1], src_nb[2], src_nb[3]);
        printf("cpy: %lld x %lld x %lld x %lld\n", dst_ne[0], dst_ne[1], dst_ne[2], dst_ne[3]);
        printf("cpy: %lld x %lld x %lld x %lld\n", dst_nb[0], dst_nb[1], dst_nb[2], dst_nb[3]);
        printf("cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));

        if (src0t == GGML_TYPE_F32 && dstt == GGML_TYPE_F16) {
            [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16];
        } else {
            GGML_ASSERT(false && "not implemented");
        }

        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];

        // kernel arguments 2..17: src ne[0..3], src nb[0..3], dst ne[0..3], dst nb[0..3]
        for (int d = 0; d < 4; ++d) {
            [encoder setBytes:&src_ne[d] length:sizeof( int64_t) atIndex: 2 + d];
            [encoder setBytes:&src_nb[d] length:sizeof(uint64_t) atIndex: 6 + d];
            [encoder setBytes:&dst_ne[d] length:sizeof( int64_t) atIndex:10 + d];
            [encoder setBytes:&dst_nb[d] length:sizeof(uint64_t) atIndex:14 + d];
        }

        // one threadgroup per source row (ne01 x ne02 x ne03), 32 threads in each
        [encoder dispatchThreadgroups:MTLSizeMake(src_ne[1], src_ne[2], src_ne[3]) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
    } break;
|
||||||
default:
|
default:
|
||||||
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
|
@ -568,9 +646,10 @@ int llama_mtl_eval(
|
||||||
|
|
||||||
{
|
{
|
||||||
struct ggml_tensor * t = ggml_get_tensor(ctx->ctx_eval, "mtl-check");
|
struct ggml_tensor * t = ggml_get_tensor(ctx->ctx_eval, "mtl-check");
|
||||||
float * data = (float *) ctx->out.contents;
|
if (t->type == GGML_TYPE_F32) {
|
||||||
|
const const float * data = (float *) ctx->out.contents;
|
||||||
printf("data: ");
|
printf("data: ");
|
||||||
int n = t->ne[0];
|
int n = ggml_nelements(t);
|
||||||
if (n > 10) {
|
if (n > 10) {
|
||||||
n = 10;
|
n = 10;
|
||||||
}
|
}
|
||||||
|
@ -583,6 +662,25 @@ int llama_mtl_eval(
|
||||||
sum += data[i];
|
sum += data[i];
|
||||||
}
|
}
|
||||||
printf("sum: %f\n", sum);
|
printf("sum: %f\n", sum);
|
||||||
|
} else if (t->type == GGML_TYPE_F16) {
|
||||||
|
const ggml_fp16_t * data = (const ggml_fp16_t *) ctx->out.contents;
|
||||||
|
printf("data: ");
|
||||||
|
int n = ggml_nelements(t);
|
||||||
|
if (n > 10) {
|
||||||
|
n = 10;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
printf("%f ", ggml_fp16_to_fp32(data[i]));
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
double sum = 0.0;
|
||||||
|
for (int i = 0; i < ggml_nelements(t); i++) {
|
||||||
|
sum += ggml_fp16_to_fp32(data[i]);
|
||||||
|
}
|
||||||
|
printf("sum: %f\n", sum);
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(false && "not implemented");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
|
|
|
@ -265,3 +265,45 @@ kernel void kernel_rope(
|
||||||
// TODO: implement
|
// TODO: implement
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Copies an f32 tensor into an f16 tensor, converting every element to half.
//
// Each threadgroup handles one source row, selected by its position in the grid
// (x -> i01, y -> i02, z -> i03); the threads of the group stride over the ne00
// elements of that row.
//
// ne00..ne03 / nb00..nb03 are the source element counts and byte strides,
// ne0..ne3 / nb0..nb3 the destination ones. The first element of the source row is
// mapped to destination coordinates through its flat element index.
//
// NOTE(review): the row is written through consecutive half elements starting at the
// mapped destination address - this looks like it assumes the destination is
// contiguous along the row; confirm against the callers.
kernel void kernel_cpy_f32_f16(
        device const float * src0,
        device       half  * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row coordinates of this threadgroup
    const int64_t row1 = tgpig.x;
    const int64_t row2 = tgpig.y;
    const int64_t row3 = tgpig.z;

    // flat element index of the first element of the source row
    const int64_t flat = ((row3*ne02 + row2)*ne01 + row1)*ne00;

    // unravel the flat index into destination coordinates (d3, d2, d1, d0)
    const int64_t plane = ne1*ne0;
    const int64_t cube  = ne2*plane;

    const int64_t d3   = flat / cube;
    const int64_t rem3 = flat - d3*cube;
    const int64_t d2   = rem3 / plane;
    const int64_t rem2 = rem3 - d2*plane;
    const int64_t d1   = rem2 / ne0;
    const int64_t d0   = rem2 - d1*ne0;

    device half * dst_row = (device half *) ((device char *) dst + d3*nb3 + d2*nb2 + d1*nb1 + d0*nb0);

    // byte address of the first element of the source row
    device const char * src_row = (device const char *) src0 + row3*nb03 + row2*nb02 + row1*nb01;

    for (int64_t col = tpitg.x; col < ne00; col += ntg.x) {
        // the float -> half conversion happens on assignment
        dst_row[col] = *((device const float *) (src_row + col*nb00));
    }
}
|
||||||
|
|
44
llama.cpp
44
llama.cpp
|
@ -1283,18 +1283,21 @@ static bool llama_eval_internal(
|
||||||
{
|
{
|
||||||
// compute the transposed [N, n_embd] V matrix
|
// compute the transposed [N, n_embd] V matrix
|
||||||
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
|
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
|
||||||
// TODO: TMP !!!!
|
|
||||||
if (il == 0) {
|
|
||||||
ggml_set_name(Vcur, "mtl-check");
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
|
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
|
||||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
|
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
|
||||||
( n_ctx)*ggml_element_size(kv_self.v),
|
( n_ctx)*ggml_element_size(kv_self.v),
|
||||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
|
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
|
||||||
|
|
||||||
|
struct ggml_tensor * t = ggml_cpy(ctx0, Kcur, k);
|
||||||
|
// TODO: TMP !!!!
|
||||||
|
if (il == 0) {
|
||||||
|
ggml_set_name(t, "mtl-check");
|
||||||
|
}
|
||||||
|
|
||||||
// important: storing RoPE-ed version of K in the KV cache!
|
// important: storing RoPE-ed version of K in the KV cache!
|
||||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
|
//ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
|
||||||
|
ggml_build_forward_expand(&gf, t);
|
||||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
|
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1448,7 +1451,7 @@ static bool llama_eval_internal(
|
||||||
|
|
||||||
// print
|
// print
|
||||||
{
|
{
|
||||||
auto print_t = [&](struct ggml_tensor * t) {
|
auto print_t_f32 = [&](struct ggml_tensor * t) {
|
||||||
float * data = (float *)t->data;
|
float * data = (float *)t->data;
|
||||||
printf("data: ");
|
printf("data: ");
|
||||||
for (int i = 0; i < std::min((int) t->ne[0], 10); i++) {
|
for (int i = 0; i < std::min((int) t->ne[0], 10); i++) {
|
||||||
|
@ -1461,9 +1464,36 @@ static bool llama_eval_internal(
|
||||||
}
|
}
|
||||||
printf("sum: %f\n", sum);
|
printf("sum: %f\n", sum);
|
||||||
};
|
};
|
||||||
|
auto print_t_f16 = [&](struct ggml_tensor * t) {
|
||||||
|
ggml_fp16_t * data = (ggml_fp16_t *)t->data;
|
||||||
|
printf("data: ");
|
||||||
|
for (int i = 0; i < std::min((int) t->ne[0], 10); i++) {
|
||||||
|
printf("%f ", ggml_fp16_to_fp32(data[i]));
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
double sum = 0.0;
|
||||||
|
for (int i = 0; i < ggml_nelements(t); i++) {
|
||||||
|
sum += ggml_fp16_to_fp32(data[i]);
|
||||||
|
}
|
||||||
|
printf("sum: %f\n", sum);
|
||||||
|
};
|
||||||
|
|
||||||
ggml_graph_compute(ctx0, &gf_export);
|
ggml_graph_compute(ctx0, &gf_export);
|
||||||
print_t(ggml_get_tensor(ctx0, "mtl-check"));
|
|
||||||
|
{
|
||||||
|
auto * t = ggml_get_tensor(ctx0, "mtl-check");
|
||||||
|
switch (t->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
print_t_f32(t);
|
||||||
|
break;
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
print_t_f16(t);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
fprintf(stderr, "%s: unsupported type\n", __func__);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cgraph_fname) {
|
if (cgraph_fname) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue