improve ggml_out_prod performance
- change iteration order (>15s -> 10s runtime)
- parallelize over one more dimension: over dst matrix rows (10s -> <5s runtime)
This commit is contained in:
parent
19fb91899b
commit
ec881156f6
1 changed file with 25 additions and 26 deletions
33
ggml.c
33
ggml.c
|
@ -9917,42 +9917,42 @@ static void ggml_compute_forward_out_prod_f32(
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// parallelize by last two dimensions
|
// parallelize by last three dimensions
|
||||||
|
|
||||||
// total parallel in src0
|
// total rows in dst
|
||||||
const int64_t np = ne02*ne03;
|
const int64_t nr = ne1*ne2*ne3;
|
||||||
|
|
||||||
// per thread
|
// rows per thread
|
||||||
const int64_t dp = (np + nth - 1)/nth;
|
const int64_t dr = (nr + nth - 1)/nth;
|
||||||
|
|
||||||
// range for this thread
|
// row range for this thread
|
||||||
const int64_t ip0 = dp*ith;
|
const int64_t ir0 = dr*ith;
|
||||||
const int64_t ip1 = MIN(ip0 + dp, np);
|
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
// dst[:,:,:,:] = 0
|
// dst[:,:,:,:] = 0
|
||||||
// for i2,i3:
|
// for i2,i3:
|
||||||
// for i01:
|
|
||||||
// for i1:
|
// for i1:
|
||||||
|
// for i01:
|
||||||
// for i0:
|
// for i0:
|
||||||
// dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
|
// dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
|
||||||
|
|
||||||
for (int64_t ip = ip0; ip < ip1; ++ip) {
|
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||||
// src0 indices
|
// dst indices
|
||||||
const int64_t i3 = ip/ne02;
|
const int64_t i3 = ir/(ne2*ne1);
|
||||||
const int64_t i2 = ip - i3*ne02;
|
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
|
||||||
|
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
||||||
|
|
||||||
const int64_t i02 = i2;
|
const int64_t i02 = i2;
|
||||||
const int64_t i03 = i3;
|
const int64_t i03 = i3;
|
||||||
|
|
||||||
|
const int64_t i10 = i1;
|
||||||
const int64_t i12 = i2;
|
const int64_t i12 = i2;
|
||||||
const int64_t i13 = i3;
|
const int64_t i13 = i3;
|
||||||
|
|
||||||
|
|
||||||
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
||||||
const int64_t i11 = i01;
|
const int64_t i11 = i01;
|
||||||
|
|
||||||
for (int64_t i1 = 0; i1 < ne1; ++i1) {
|
|
||||||
const int64_t i10 = i1;
|
|
||||||
|
|
||||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||||
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||||
|
@ -9963,7 +9963,6 @@ static void ggml_compute_forward_out_prod_f32(
|
||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
//int64_t t1 = ggml_perf_time_us();
|
//int64_t t1 = ggml_perf_time_us();
|
||||||
//static int64_t acc = 0;
|
//static int64_t acc = 0;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue