improve ggml_out_prod performance

- change iteration order (>15s -> 10s runtime)
- parallelize over one more dimension: over dst matrix rows (10s -> <5s runtime)
This commit is contained in:
xaedes 2023-05-15 14:42:24 +02:00
parent 19fb91899b
commit ec881156f6
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

49
ggml.c
View file

@ -9917,51 +9917,50 @@ static void ggml_compute_forward_out_prod_f32(
return; return;
} }
// parallelize by last two dimensions // parallelize by last three dimensions
// total parallel in src0 // total rows in dst
const int64_t np = ne02*ne03; const int64_t nr = ne1*ne2*ne3;
// per thread // rows per thread
const int64_t dp = (np + nth - 1)/nth; const int64_t dr = (nr + nth - 1)/nth;
// range for this thread // row range for this thread
const int64_t ip0 = dp*ith; const int64_t ir0 = dr*ith;
const int64_t ip1 = MIN(ip0 + dp, np); const int64_t ir1 = MIN(ir0 + dr, nr);
// dst[:,:,:,:] = 0 // dst[:,:,:,:] = 0
// for i2,i3: // for i2,i3:
// for i01: // for i1:
// for i1: // for i01:
// for i0: // for i0:
// dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
for (int64_t ip = ip0; ip < ip1; ++ip) { for (int64_t ir = ir0; ir < ir1; ++ir) {
// src0 indices // dst indices
const int64_t i3 = ip/ne02; const int64_t i3 = ir/(ne2*ne1);
const int64_t i2 = ip - i3*ne02; const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
const int64_t i02 = i2; const int64_t i02 = i2;
const int64_t i03 = i3; const int64_t i03 = i3;
const int64_t i10 = i1;
const int64_t i12 = i2; const int64_t i12 = i2;
const int64_t i13 = i3; const int64_t i13 = i3;
for (int64_t i01 = 0; i01 < ne01; ++i01) { for (int64_t i01 = 0; i01 < ne01; ++i01) {
const int64_t i11 = i01; const int64_t i11 = i01;
for (int64_t i1 = 0; i1 < ne1; ++i1) { float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
const int64_t i10 = i1; float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); ggml_vec_mad_f32(ne0, d, s0, *s1);
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); // for (int64_t i0 = 0; i0 < ne0; ++i0) {
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); // d[i0] += s0[i0] * s1[i1];
// }
ggml_vec_mad_f32(ne0, d, s0, *s1);
// for (int64_t i0 = 0; i0 < ne0; ++i0) {
// d[i0] += s0[i0] * s1[i1];
// }
}
} }
} }