block tiling for out-prod inspired by mul-mat

block sizes are empirically optimized

roughly doubles the flops of out-prod
This commit is contained in:
xaedes 2023-09-14 18:39:46 +02:00
parent 0971fee710
commit d88dae2980
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

21
ggml.c
View file

@ -11817,7 +11817,15 @@ static void ggml_compute_forward_out_prod_f32(
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
for (int64_t ir = ir0; ir < ir1; ++ir) {
// block-tiling attempt
const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
const int64_t blck_1 = 16;
for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
const int64_t bir1 = MIN(bir + blck_1, ir1);
for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
const int64_t bne01 = MIN(bi01 + blck_0, ne01);
for (int64_t ir = bir; ir < bir1; ++ir) {
// dst indices
const int64_t i3 = ir/(ne2*ne1);
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
@ -11829,9 +11837,10 @@ static void ggml_compute_forward_out_prod_f32(
//const int64_t i10 = i1;
const int64_t i12 = i2;
const int64_t i13 = i3;
#if GGML_VEC_MAD_UNROLL > 2
const int64_t ne01_unroll = ne01 - (ne01 % GGML_VEC_MAD_UNROLL);
for (int64_t i01 = 0; i01 < ne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
const int64_t i11 = i01;
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
@ -11840,7 +11849,7 @@ static void ggml_compute_forward_out_prod_f32(
ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
}
for (int64_t i01 = ne01_unroll; i01 < ne01; ++i01) {
for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
const int64_t i11 = i01;
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
@ -11850,7 +11859,7 @@ static void ggml_compute_forward_out_prod_f32(
ggml_vec_mad_f32(ne0, d, s0, *s1);
}
#else
for (int64_t i01 = 0; i01 < ne01; ++i01) {
for (int64_t i01 = bi01; i01 < bne01; ++i01) {
const int64_t i11 = i01;
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
@ -11861,6 +11870,8 @@ static void ggml_compute_forward_out_prod_f32(
}
#endif
}
}
}
//int64_t t1 = ggml_perf_time_us();