From ec881156f6ad56ed552c06698083f5a263ff0a6d Mon Sep 17 00:00:00 2001 From: xaedes Date: Mon, 15 May 2023 14:42:24 +0200 Subject: [PATCH] improve ggml_out_prod performance - change iteration order (>15s -> 10s runtime) - parallelize over one more dimension: over dst matrix rows (10s -> <5s runtime) --- ggml.c | 51 +++++++++++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/ggml.c b/ggml.c index 77b654809..52a9c9bcc 100644 --- a/ggml.c +++ b/ggml.c @@ -9917,51 +9917,50 @@ static void ggml_compute_forward_out_prod_f32( return; } - // parallelize by last two dimensions + // parallelize by last three dimensions - // total parallel in src0 - const int64_t np = ne02*ne03; + // total rows in dst + const int64_t nr = ne1*ne2*ne3; - // per thread - const int64_t dp = (np + nth - 1)/nth; + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; - // range for this thread - const int64_t ip0 = dp*ith; - const int64_t ip1 = MIN(ip0 + dp, np); + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); // dst[:,:,:,:] = 0 // for i2,i3: - // for i01: - // for i1: + // for i1: + // for i01: // for i0: // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] - for (int64_t ip = ip0; ip < ip1; ++ip) { - // src0 indices - const int64_t i3 = ip/ne02; - const int64_t i2 = ip - i3*ne02; - + for (int64_t ir = ir0; ir < ir1; ++ir) { + // dst indices + const int64_t i3 = ir/(ne2*ne1); + const int64_t i2 = (ir - i3*ne2*ne1)/ne1; + const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); + const int64_t i02 = i2; const int64_t i03 = i3; + const int64_t i10 = i1; const int64_t i12 = i2; const int64_t i13 = i3; + for (int64_t i01 = 0; i01 < ne01; ++i01) { const int64_t i11 = i01; - for (int64_t i1 = 0; i1 < ne1; ++i1) { - const int64_t i10 = i1; + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); - float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - - ggml_vec_mad_f32(ne0, d, s0, *s1); - // for (int64_t i0 = 0; i0 < ne0; ++i0) { - // d[i0] += s0[i0] * s1[i1]; - // } - } + ggml_vec_mad_f32(ne0, d, s0, *s1); + // for (int64_t i0 = 0; i0 < ne0; ++i0) { + // d[i0] += s0[i0] * s1[i1]; + // } } }