improve ggml_out_prod performance

- change iteration order (>15s -> 10s runtime)
- parallelize over one more dimension: over dst matrix rows (10s -> <5s runtime)
This commit is contained in:
xaedes 2023-05-15 14:42:24 +02:00
parent 19fb91899b
commit ec881156f6
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

33
ggml.c
View file

@ -9917,42 +9917,42 @@ static void ggml_compute_forward_out_prod_f32(
return;
}
// parallelize by last two dimensions
// parallelize by last three dimensions
// total parallel in src0
const int64_t np = ne02*ne03;
// total rows in dst
const int64_t nr = ne1*ne2*ne3;
// per thread
const int64_t dp = (np + nth - 1)/nth;
// rows per thread
const int64_t dr = (nr + nth - 1)/nth;
// range for this thread
const int64_t ip0 = dp*ith;
const int64_t ip1 = MIN(ip0 + dp, np);
// row range for this thread
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
// dst[:,:,:,:] = 0
// for i2,i3:
// for i01:
// for i1:
// for i01:
// for i0:
// dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
for (int64_t ip = ip0; ip < ip1; ++ip) {
// src0 indices
const int64_t i3 = ip/ne02;
const int64_t i2 = ip - i3*ne02;
for (int64_t ir = ir0; ir < ir1; ++ir) {
// dst indices
const int64_t i3 = ir/(ne2*ne1);
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
const int64_t i02 = i2;
const int64_t i03 = i3;
const int64_t i10 = i1;
const int64_t i12 = i2;
const int64_t i13 = i3;
for (int64_t i01 = 0; i01 < ne01; ++i01) {
const int64_t i11 = i01;
for (int64_t i1 = 0; i1 < ne1; ++i1) {
const int64_t i10 = i1;
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
@ -9963,7 +9963,6 @@ static void ggml_compute_forward_out_prod_f32(
// }
}
}
}
//int64_t t1 = ggml_perf_time_us();
//static int64_t acc = 0;