mtl : optimize rms_norm and soft_max kernels

Georgi Gerganov 2023-06-01 22:51:42 +03:00
parent 9665429e94
commit f0196a7e7a
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
5 changed files with 166 additions and 47 deletions

View file

@@ -41,13 +41,15 @@ int main(int argc, char ** argv) {
// TODO: tmp to match the input used when creating the cgraph
{
const int n_past = 128;
const int n_batch = 32;
const int n_batch = 1;
const int n_past = 512 - n_batch;
const std::vector<int> tmp(n_batch, 1); // BOS
// the actual inference happens here
llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past);
for (int i = 0; i < 10; ++i) {
llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past);
}
}
llama_mtl_free(ctx_mtl);
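The harness change above switches to single-token decoding (n_batch = 1) and repeats the eval ten times, which makes the per-call GPU time easier to average. A hedged sketch of how the average wall-clock cost of such a loop could be measured on the host side; the branch-specific llama_mtl_eval call is stood in by a generic callable:

    #include <chrono>

    // Average wall-clock milliseconds over n_iter calls to fn().
    // fn() stands in for a call like llama_mtl_eval(ctx_mtl, &gf, ...).
    template <typename F>
    static double avg_eval_ms(int n_iter, F && fn) {
        const auto t0 = std::chrono::high_resolution_clock::now();
        for (int i = 0; i < n_iter; ++i) {
            fn();
        }
        const auto t1 = std::chrono::high_resolution_clock::now();
        return std::chrono::duration<double, std::milli>(t1 - t0).count() / n_iter;
    }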

View file

@@ -429,14 +429,17 @@ int llama_mtl_eval(
const int64_t ne02 = gf->nodes[i]->src0->ne[2];
const int64_t ne03 = gf->nodes[i]->src0->ne[3];
const int nth = 32;
[encoder setComputePipelineState:ctx->pipeline_soft_max];
[encoder setBuffer:id_src offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_DIAG_MASK_INF:
{
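For context, each threadgroup launched by the soft_max dispatch above handles exactly one row of src0, and the nth*sizeof(float) of threadgroup memory holds the per-thread partial results. A plain C++ sketch of the serial per-row computation the kernel has to reproduce (an illustrative reference, not code from this tree):

    #include <cmath>
    #include <cstdint>

    // Serial reference for one row of soft_max:
    // pdst[i] = exp(psrc0[i] - max) / sum, with the row max subtracted
    // first so that exp() cannot overflow.
    static void soft_max_row_ref(const float * psrc0, float * pdst, int64_t ne00) {
        float max = -INFINITY;
        for (int64_t i = 0; i < ne00; ++i) {
            max = std::fmax(max, psrc0[i]);
        }
        float sum = 0.0f;
        for (int64_t i = 0; i < ne00; ++i) {
            pdst[i] = std::exp(psrc0[i] - max);
            sum += pdst[i];
        }
        for (int64_t i = 0; i < ne00; ++i) {
            pdst[i] /= sum;
        }
    }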
@@ -494,10 +497,10 @@ int llama_mtl_eval(
const enum ggml_type src1t = gf->nodes[i]->src1->type;
const enum ggml_type dstt = gf->nodes[i]->type;
printf("mul_mat: src0 - %s[%lld, %lld, %lld]\n", ggml_type_name(src0t), ne00, ne01, ne02);
printf("mul_mat: src1 - %s[%lld, %lld, %lld]\n", ggml_type_name(src1t), ne10, ne11, ne12);
printf("mul_mat: dst - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2);
printf("mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt));
fprintf(stderr, "mul_mat: src0 - %s[%lld, %lld, %lld]\n", ggml_type_name(src0t), ne00, ne01, ne02);
fprintf(stderr, "mul_mat: src1 - %s[%lld, %lld, %lld]\n", ggml_type_name(src1t), ne10, ne11, ne12);
fprintf(stderr, "mul_mat: dst - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2);
fprintf(stderr, "mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt));
GGML_ASSERT(ne00 == ne10);
GGML_ASSERT(ne02 == ne12);
@@ -599,16 +602,19 @@ int llama_mtl_eval(
const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
const float eps = 1e-6f;
const int nth = 32;
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(gf->nodes[i]->src0);
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_ROPE:
{
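The rms_norm dispatch above uses the same one-threadgroup-per-row scheme, again with nth partial sums kept in threadgroup memory. A serial C++ reference of the per-row result, using the same eps = 1e-6f (illustrative only, not code from this tree):

    #include <cmath>
    #include <cstdint>

    // Serial reference for one row of rms_norm: y = x / sqrt(mean(x^2) + eps).
    static void rms_norm_row_ref(const float * x, float * y, int64_t ne00, float eps) {
        float sum = 0.0f;
        for (int64_t i = 0; i < ne00; ++i) {
            sum += x[i] * x[i];
        }
        const float mean  = sum / ne00;
        const float scale = 1.0f / std::sqrt(mean + eps);
        for (int64_t i = 0; i < ne00; ++i) {
            y[i] = x[i] * scale;
        }
    }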
@@ -643,9 +649,9 @@ int llama_mtl_eval(
const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
const int mode = ((int32_t *) gf->nodes[i]->src1->data)[2];
printf("rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
printf("rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
printf("rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
fprintf(stderr, "rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
fprintf(stderr, "rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
fprintf(stderr, "rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
[encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -704,11 +710,13 @@ int llama_mtl_eval(
const enum ggml_type src0t = gf->nodes[i]->src0->type;
const enum ggml_type dstt = gf->nodes[i]->type;
printf("cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
printf("cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
printf("cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
printf("cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3);
printf("cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));
const int nth = 32;
fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3);
fprintf(stderr, "cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));
switch (src0t) {
case GGML_TYPE_F32:
@@ -741,7 +749,7 @@ int llama_mtl_eval(
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
default:
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
@@ -764,8 +772,6 @@ int llama_mtl_eval(
id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, out, &offs_src0);
id<MTLBuffer> id_dst = ctx->out;
printf("XXXXX n = %d\n", ggml_nelements(out));
id<MTLBlitCommandEncoder> encoder_blit = [command_buffer blitCommandEncoder];
[encoder_blit copyFromBuffer:id_src sourceOffset:offs_src0 toBuffer:id_dst destinationOffset:0 size:ggml_nbytes(out)];
[encoder_blit endEncoding];
@@ -776,12 +782,29 @@ int llama_mtl_eval(
{
const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
fprintf(stderr, "%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
}
// TODO
const float * logits = ctx->out.contents;
printf("logits: ");
for (int i = 0; i < 100; i++) {
printf("%8.4f ", logits[i]);
}
printf("\n");
double sum = 0.0;
int imax = 0;
double vmax = -INFINITY;
for (int i = 0; i < 32000; i++) {
sum += (double) logits[i];
if (logits[i] > vmax) {
vmax = logits[i];
imax = i;
}
}
printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
//{
// struct ggml_tensor * t = ggml_get_tensor(ctx->ctx_eval, "mtl-check");
// if (t->type == GGML_TYPE_F32) {

View file

@@ -87,25 +87,80 @@ kernel void kernel_soft_max(
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
uint3 tpig[[thread_position_in_grid]]) {
const int64_t i03 = tpig[2];
const int64_t i02 = tpig[1];
const int64_t i01 = tpig[0];
threadgroup float * buf [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
float max = 0.0f;
for (int i = 0; i < ne00; i++) {
max = MAX(max, psrc0[i]);
//float max = 0.0f;
//for (int i = 0; i < ne00; i++) {
// max = MAX(max, psrc0[i]);
//}
//float sum = 0.0f;
//for (int i = 0; i < ne00; i++) {
// pdst[i] = exp(psrc0[i] - max);
// sum += pdst[i];
//}
//for (int i = 0; i < ne00; i++) {
// pdst[i] /= sum;
//}
// parallel max
buf[tpitg[0]] = -INFINITY;
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
}
float sum = 0.0f;
for (int i = 0; i < ne00; i++) {
pdst[i] = exp(psrc0[i] - max);
sum += pdst[i];
// reduce
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = ntg[0]/2; i > 0; i /= 2) {
if (tpitg[0] < i) {
buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
for (int i = 0; i < ne00; i++) {
pdst[i] /= sum;
// broadcast
if (tpitg[0] == 0) {
buf[0] = buf[0];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
const float max = buf[0];
// parallel sum
buf[tpitg[0]] = 0.0f;
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
buf[tpitg[0]] += exp(psrc0[i00] - max);
}
// reduce
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = ntg[0]/2; i > 0; i /= 2) {
if (tpitg[0] < i) {
buf[tpitg[0]] += buf[tpitg[0] + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// broadcast
if (tpitg[0] == 0) {
buf[0] = buf[0];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
const float sum = buf[0];
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
pdst[i00] = exp(psrc0[i00] - max) / sum;
}
}
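The rewritten kernel replaces the serial max and sum loops with a strided accumulation into the threadgroup buffer followed by a binary tree reduction; the halving loop (i = ntg[0]/2, then i /= 2) assumes the threadgroup size is a power of two (nth = 32 here). A host-side C++ sketch that mimics the same two reductions, useful for checking that the parallel scheme matches the serial reference (illustrative, not code from this tree):

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Simulate what nth cooperating threads do for one row of soft_max:
    // a strided partial reduction into buf[0..nth), then a tree reduction on buf.
    static void soft_max_row_parallel_sim(const float * psrc0, float * pdst, int64_t ne00, int nth) {
        std::vector<float> buf(nth);

        // parallel max: "thread" t scans elements t, t + nth, t + 2*nth, ...
        for (int t = 0; t < nth; ++t) {
            buf[t] = -INFINITY;
            for (int64_t i = t; i < ne00; i += nth) {
                buf[t] = std::fmax(buf[t], psrc0[i]);
            }
        }
        // tree reduction (requires nth to be a power of two)
        for (int i = nth/2; i > 0; i /= 2) {
            for (int t = 0; t < i; ++t) {
                buf[t] = std::fmax(buf[t], buf[t + i]);
            }
        }
        const float max = buf[0];

        // parallel sum of exp(x - max), same strided + tree pattern
        for (int t = 0; t < nth; ++t) {
            buf[t] = 0.0f;
            for (int64_t i = t; i < ne00; i += nth) {
                buf[t] += std::exp(psrc0[i] - max);
            }
        }
        for (int i = nth/2; i > 0; i /= 2) {
            for (int t = 0; t < i; ++t) {
                buf[t] += buf[t + i];
            }
        }
        const float sum = buf[0];

        for (int64_t i = 0; i < ne00; ++i) {
            pdst[i] = std::exp(psrc0[i] - max) / sum;
        }
    }

The only difference from the serial reference is the order of summation, so the two results can differ by a few ULPs.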
@@ -149,19 +204,39 @@ kernel void kernel_rms_norm(
constant int64_t & ne00,
constant uint64_t & nb01,
constant float & eps,
uint tpig[[thread_position_in_grid]]) {
device const float * x = (device const float *) ((device const char *) src0 + tpig*nb01);
threadgroup float * sum [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
float sum = 0.0f;
for (int i00 = 0; i00 < ne00; i00++) {
sum += x[i00] * x[i00];
// parallel sum
sum[tpitg] = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
sum[tpitg] += x[i00] * x[i00];
}
const float mean = sum/ne00;
// reduce
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = ntg/2; i > 0; i /= 2) {
if (tpitg < i) {
sum[tpitg] += sum[tpitg + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// broadcast
if (tpitg == 0) {
sum[0] /= ne00;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
const float mean = sum[0];
const float scale = 1.0f/sqrt(mean + eps);
device float * y = dst + tpig*ne00;
for (int i00 = 0; i00 < ne00; i00++) {
device float * y = dst + tgpig*ne00;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
y[i00] = x[i00] * scale;
}
}
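kernel_rms_norm applies the same pattern with a single sum-of-squares reduction, and thread 0 divides by ne00 once before the broadcast. A compact C++ sketch of that reduction, again assuming a power-of-two thread count (illustrative, not code from this tree):

    #include <cstdint>
    #include <vector>

    // Sum of squares of x[0..ne00) computed the way the kernel does it:
    // strided partial sums per "thread", then a tree reduction.
    static float sum_squares_parallel_sim(const float * x, int64_t ne00, int ntg) {
        std::vector<float> sum(ntg, 0.0f);
        for (int t = 0; t < ntg; ++t) {
            for (int64_t i = t; i < ne00; i += ntg) {
                sum[t] += x[i] * x[i];
            }
        }
        for (int i = ntg/2; i > 0; i /= 2) {  // requires ntg to be a power of two
            for (int t = 0; t < i; ++t) {
                sum[t] += sum[t + i];
            }
        }
        return sum[0];  // the kernel then takes mean = sum/ne00 and scale = 1/sqrt(mean + eps)
    }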

ggml.c
View file

@@ -14647,8 +14647,8 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
}
void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
assert(cgraph->work == NULL);
assert(cgraph->work_size == 0);
//assert(cgraph->work == NULL);
//assert(cgraph->work_size == 0);
uint64_t size_eval = 0;

View file

@@ -1506,6 +1506,25 @@ static bool llama_eval_internal(
if (cgraph_fname) {
ggml_graph_export(&gf, cgraph_fname);
float * logits = (float *) ggml_get_data(inpL);
printf("logits: ");
for (int i = 0; i < 10; i++) {
printf("%8.4f ", logits[i]);
}
printf("\n");
double sum = 0.0;
int imax = 0;
double vmax = -INFINITY;
for (int i = 0; i < 32000; i++) {
sum += (double) logits[i];
if (logits[i] > vmax) {
vmax = logits[i];
imax = i;
}
}
printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
}
#ifdef GGML_PERF
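The block added above prints the same sum/argmax statistics for the CPU logits that the Metal path prints in llama_mtl_eval, so the two runs can be compared by eye. A small hypothetical helper for a stricter element-wise comparison of the two logit buffers (the function name and tolerance are assumptions, not code from this tree):

    #include <cmath>
    #include <cstdio>

    // Compare two logit buffers (e.g. Metal output vs CPU reference) and report
    // the largest absolute difference; returns true if every element is within tol.
    static bool logits_match(const float * a, const float * b, int n, float tol) {
        float max_diff = 0.0f;
        int   i_diff   = -1;
        for (int i = 0; i < n; ++i) {
            const float d = std::fabs(a[i] - b[i]);
            if (d > max_diff) {
                max_diff = d;
                i_diff   = i;
            }
        }
        printf("max |diff| = %g at i = %d\n", max_diff, i_diff);
        return max_diff <= tol;
    }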
@@ -3002,11 +3021,11 @@ int llama_eval(
int llama_eval_export(struct llama_context * ctx, const char * fname) {
// these values determine the maximum inference sizes of the exported computation graph
// TODO: TMP !!!
// TODO: need to increase buffers to support the full context
//const int n_ctx = ctx->model.hparams.n_ctx;
//const int n_batch = 512;
const int n_ctx = 128;
const int n_batch = 32;
const int n_batch = 1;
const int n_ctx = 512 - n_batch;
const std::vector<llama_token> tmp(n_batch, llama_token_bos());