improve ggml_out_prod performance
- change iteration order (>15s -> 10s runtime) - parallelize over one more dimension: over dst matrix rows (10s -> <5s runtime)
This commit is contained in:
parent
19fb91899b
commit
ec881156f6
1 changed files with 25 additions and 26 deletions
33
ggml.c
33
ggml.c
|
@ -9917,42 +9917,42 @@ static void ggml_compute_forward_out_prod_f32(
|
|||
return;
|
||||
}
|
||||
|
||||
// parallelize by last two dimensions
|
||||
// parallelize by last three dimensions
|
||||
|
||||
// total parallel in src0
|
||||
const int64_t np = ne02*ne03;
|
||||
// total rows in dst
|
||||
const int64_t nr = ne1*ne2*ne3;
|
||||
|
||||
// per thread
|
||||
const int64_t dp = (np + nth - 1)/nth;
|
||||
// rows per thread
|
||||
const int64_t dr = (nr + nth - 1)/nth;
|
||||
|
||||
// range for this thread
|
||||
const int64_t ip0 = dp*ith;
|
||||
const int64_t ip1 = MIN(ip0 + dp, np);
|
||||
// row range for this thread
|
||||
const int64_t ir0 = dr*ith;
|
||||
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
// dst[:,:,:,:] = 0
|
||||
// for i2,i3:
|
||||
// for i01:
|
||||
// for i1:
|
||||
// for i01:
|
||||
// for i0:
|
||||
// dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
|
||||
|
||||
for (int64_t ip = ip0; ip < ip1; ++ip) {
|
||||
// src0 indices
|
||||
const int64_t i3 = ip/ne02;
|
||||
const int64_t i2 = ip - i3*ne02;
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
// dst indices
|
||||
const int64_t i3 = ir/(ne2*ne1);
|
||||
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
|
||||
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
||||
|
||||
const int64_t i02 = i2;
|
||||
const int64_t i03 = i3;
|
||||
|
||||
const int64_t i10 = i1;
|
||||
const int64_t i12 = i2;
|
||||
const int64_t i13 = i3;
|
||||
|
||||
|
||||
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
||||
const int64_t i11 = i01;
|
||||
|
||||
for (int64_t i1 = 0; i1 < ne1; ++i1) {
|
||||
const int64_t i10 = i1;
|
||||
|
||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
|
@ -9963,7 +9963,6 @@ static void ggml_compute_forward_out_prod_f32(
|
|||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//int64_t t1 = ggml_perf_time_us();
|
||||
//static int64_t acc = 0;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue