mtl : optimize rms_norm and soft_max kernels

This commit is contained in:
Georgi Gerganov 2023-06-01 22:51:42 +03:00
parent 9665429e94
commit f0196a7e7a
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
5 changed files with 166 additions and 47 deletions

View file

@ -41,13 +41,15 @@ int main(int argc, char ** argv) {
// TODO: tmp to match the input used when creating the cgraph // TODO: tmp to match the input used when creating the cgraph
{ {
const int n_past = 128; const int n_batch = 1;
const int n_batch = 32; const int n_past = 512 - n_batch;
const std::vector<int> tmp(n_batch, 1); // BOS const std::vector<int> tmp(n_batch, 1); // BOS
// the actual inference happens here // the actual inference happens here
llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past); for (int i = 0; i < 10; ++i) {
llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past);
}
} }
llama_mtl_free(ctx_mtl); llama_mtl_free(ctx_mtl);

View file

@ -429,14 +429,17 @@ int llama_mtl_eval(
const int64_t ne02 = gf->nodes[i]->src0->ne[2]; const int64_t ne02 = gf->nodes[i]->src0->ne[2];
const int64_t ne03 = gf->nodes[i]->src0->ne[3]; const int64_t ne03 = gf->nodes[i]->src0->ne[3];
const int nth = 32;
[encoder setComputePipelineState:ctx->pipeline_soft_max]; [encoder setComputePipelineState:ctx->pipeline_soft_max];
[encoder setBuffer:id_src offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
{ {
@ -494,10 +497,10 @@ int llama_mtl_eval(
const enum ggml_type src1t = gf->nodes[i]->src1->type; const enum ggml_type src1t = gf->nodes[i]->src1->type;
const enum ggml_type dstt = gf->nodes[i]->type; const enum ggml_type dstt = gf->nodes[i]->type;
printf("mul_mat: src0 - %s[%lld, %lld, %lld]\n", ggml_type_name(src0t), ne00, ne01, ne02); fprintf(stderr, "mul_mat: src0 - %s[%lld, %lld, %lld]\n", ggml_type_name(src0t), ne00, ne01, ne02);
printf("mul_mat: src1 - %s[%lld, %lld, %lld]\n", ggml_type_name(src1t), ne10, ne11, ne12); fprintf(stderr, "mul_mat: src1 - %s[%lld, %lld, %lld]\n", ggml_type_name(src1t), ne10, ne11, ne12);
printf("mul_mat: dst - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2); fprintf(stderr, "mul_mat: dst - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2);
printf("mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt)); fprintf(stderr, "mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt));
GGML_ASSERT(ne00 == ne10); GGML_ASSERT(ne00 == ne10);
GGML_ASSERT(ne02 == ne12); GGML_ASSERT(ne02 == ne12);
@ -599,16 +602,19 @@ int llama_mtl_eval(
const uint64_t nb01 = gf->nodes[i]->src0->nb[1]; const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
const float eps = 1e-6f; const float eps = 1e-6f;
const int nth = 32;
[encoder setComputePipelineState:ctx->pipeline_rms_norm]; [encoder setComputePipelineState:ctx->pipeline_rms_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4]; [encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(gf->nodes[i]->src0); const int64_t nrows = ggml_nrows(gf->nodes[i]->src0);
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;
case GGML_OP_ROPE: case GGML_OP_ROPE:
{ {
@ -643,9 +649,9 @@ int llama_mtl_eval(
const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1]; const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
const int mode = ((int32_t *) gf->nodes[i]->src1->data)[2]; const int mode = ((int32_t *) gf->nodes[i]->src1->data)[2];
printf("rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03); fprintf(stderr, "rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
printf("rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3); fprintf(stderr, "rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
printf("rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode); fprintf(stderr, "rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
[encoder setComputePipelineState:ctx->pipeline_rope]; [encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -704,11 +710,13 @@ int llama_mtl_eval(
const enum ggml_type src0t = gf->nodes[i]->src0->type; const enum ggml_type src0t = gf->nodes[i]->src0->type;
const enum ggml_type dstt = gf->nodes[i]->type; const enum ggml_type dstt = gf->nodes[i]->type;
printf("cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03); const int nth = 32;
printf("cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
printf("cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3); fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
printf("cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3); fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
printf("cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt)); fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3);
fprintf(stderr, "cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));
switch (src0t) { switch (src0t) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
@ -741,7 +749,7 @@ int llama_mtl_eval(
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;
default: default:
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
@ -764,8 +772,6 @@ int llama_mtl_eval(
id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, out, &offs_src0); id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, out, &offs_src0);
id<MTLBuffer> id_dst = ctx->out; id<MTLBuffer> id_dst = ctx->out;
printf("XXXXX n = %d\n", ggml_nelements(out));
id<MTLBlitCommandEncoder> encoder_blit = [command_buffer blitCommandEncoder]; id<MTLBlitCommandEncoder> encoder_blit = [command_buffer blitCommandEncoder];
[encoder_blit copyFromBuffer:id_src sourceOffset:offs_src0 toBuffer:id_dst destinationOffset:0 size:ggml_nbytes(out)]; [encoder_blit copyFromBuffer:id_src sourceOffset:offs_src0 toBuffer:id_dst destinationOffset:0 size:ggml_nbytes(out)];
[encoder_blit endEncoding]; [encoder_blit endEncoding];
@ -776,12 +782,29 @@ int llama_mtl_eval(
{ {
const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime]; const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
fprintf(stderr, "%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0); printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
} }
// TODO // TODO
const float * logits = ctx->out.contents; const float * logits = ctx->out.contents;
printf("logits: ");
for (int i = 0; i < 100; i++) {
printf("%8.4f ", logits[i]);
}
printf("\n");
double sum = 0.0;
int imax = 0;
double vmax = -INFINITY;
for (int i = 0; i < 32000; i++) {
sum += (double) logits[i];
if (logits[i] > vmax) {
vmax = logits[i];
imax = i;
}
}
printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
//{ //{
// struct ggml_tensor * t = ggml_get_tensor(ctx->ctx_eval, "mtl-check"); // struct ggml_tensor * t = ggml_get_tensor(ctx->ctx_eval, "mtl-check");
// if (t->type == GGML_TYPE_F32) { // if (t->type == GGML_TYPE_F32) {

View file

@ -87,25 +87,80 @@ kernel void kernel_soft_max(
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01, constant int64_t & ne01,
constant int64_t & ne02, constant int64_t & ne02,
uint3 tpig[[thread_position_in_grid]]) { threadgroup float * buf [[threadgroup(0)]],
const int64_t i03 = tpig[2]; uint3 tgpig[[threadgroup_position_in_grid]],
const int64_t i02 = tpig[1]; uint3 tpitg[[thread_position_in_threadgroup]],
const int64_t i01 = tpig[0]; uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
float max = 0.0f; //float max = 0.0f;
for (int i = 0; i < ne00; i++) { //for (int i = 0; i < ne00; i++) {
max = MAX(max, psrc0[i]); // max = MAX(max, psrc0[i]);
//}
//float sum = 0.0f;
//for (int i = 0; i < ne00; i++) {
// pdst[i] = exp(psrc0[i] - max);
// sum += pdst[i];
//}
//for (int i = 0; i < ne00; i++) {
// pdst[i] /= sum;
//}
// parallel max
buf[tpitg[0]] = -INFINITY;
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
} }
float sum = 0.0f;
for (int i = 0; i < ne00; i++) { // reduce
pdst[i] = exp(psrc0[i] - max); threadgroup_barrier(mem_flags::mem_threadgroup);
sum += pdst[i]; for (uint i = ntg[0]/2; i > 0; i /= 2) {
if (tpitg[0] < i) {
buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
} }
for (int i = 0; i < ne00; i++) {
pdst[i] /= sum; // broadcast
if (tpitg[0] == 0) {
buf[0] = buf[0];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
const float max = buf[0];
// parallel sum
buf[tpitg[0]] = 0.0f;
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
buf[tpitg[0]] += exp(psrc0[i00] - max);
}
// reduce
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = ntg[0]/2; i > 0; i /= 2) {
if (tpitg[0] < i) {
buf[tpitg[0]] += buf[tpitg[0] + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// broadcast
if (tpitg[0] == 0) {
buf[0] = buf[0];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
const float sum = buf[0];
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
pdst[i00] = exp(psrc0[i00] - max) / sum;
} }
} }
@ -149,19 +204,39 @@ kernel void kernel_rms_norm(
constant int64_t & ne00, constant int64_t & ne00,
constant uint64_t & nb01, constant uint64_t & nb01,
constant float & eps, constant float & eps,
uint tpig[[thread_position_in_grid]]) { threadgroup float * sum [[threadgroup(0)]],
device const float * x = (device const float *) ((device const char *) src0 + tpig*nb01); uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
float sum = 0.0f; // parallel sum
for (int i00 = 0; i00 < ne00; i00++) { sum[tpitg] = 0.0f;
sum += x[i00] * x[i00]; for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
sum[tpitg] += x[i00] * x[i00];
} }
const float mean = sum/ne00; // reduce
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = ntg/2; i > 0; i /= 2) {
if (tpitg < i) {
sum[tpitg] += sum[tpitg + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// broadcast
if (tpitg == 0) {
sum[0] /= ne00;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
const float mean = sum[0];
const float scale = 1.0f/sqrt(mean + eps); const float scale = 1.0f/sqrt(mean + eps);
device float * y = dst + tpig*ne00; device float * y = dst + tgpig*ne00;
for (int i00 = 0; i00 < ne00; i00++) { for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
y[i00] = x[i00] * scale; y[i00] = x[i00] * scale;
} }
} }

4
ggml.c
View file

@ -14647,8 +14647,8 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
} }
void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
assert(cgraph->work == NULL); //assert(cgraph->work == NULL);
assert(cgraph->work_size == 0); //assert(cgraph->work_size == 0);
uint64_t size_eval = 0; uint64_t size_eval = 0;

View file

@ -1506,6 +1506,25 @@ static bool llama_eval_internal(
if (cgraph_fname) { if (cgraph_fname) {
ggml_graph_export(&gf, cgraph_fname); ggml_graph_export(&gf, cgraph_fname);
float * logits = (float *) ggml_get_data(inpL);
printf("logits: ");
for (int i = 0; i < 10; i++) {
printf("%8.4f ", logits[i]);
}
printf("\n");
double sum = 0.0;
int imax = 0;
double vmax = -INFINITY;
for (int i = 0; i < 32000; i++) {
sum += (double) logits[i];
if (logits[i] > vmax) {
vmax = logits[i];
imax = i;
}
}
printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
} }
#ifdef GGML_PERF #ifdef GGML_PERF
@ -3002,11 +3021,11 @@ int llama_eval(
int llama_eval_export(struct llama_context * ctx, const char * fname) { int llama_eval_export(struct llama_context * ctx, const char * fname) {
// these values determine the maximum inference sizes of the exported computation graph // these values determine the maximum inference sizes of the exported computation graph
// TODO: TMP !!! // TODO: need to increase buffers to support the full context
//const int n_ctx = ctx->model.hparams.n_ctx; //const int n_ctx = ctx->model.hparams.n_ctx;
//const int n_batch = 512; //const int n_batch = 512;
const int n_ctx = 128; const int n_batch = 1;
const int n_batch = 32; const int n_ctx = 512 - n_batch;
const std::vector<llama_token> tmp(n_batch, llama_token_bos()); const std::vector<llama_token> tmp(n_batch, llama_token_bos());