block tiling for out-prod inspired by mul-mat

block sizes are empirically optimized

roughly doubles the flops of out-prod
This commit is contained in:
xaedes 2023-09-14 18:39:46 +02:00
parent 0971fee710
commit d88dae2980
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

75
ggml.c
View file

@ -11817,49 +11817,60 @@ static void ggml_compute_forward_out_prod_f32(
const int64_t ir0 = dr*ith; const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr); const int64_t ir1 = MIN(ir0 + dr, nr);
for (int64_t ir = ir0; ir < ir1; ++ir) { // block-tiling attempt
// dst indices const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
const int64_t i3 = ir/(ne2*ne1); const int64_t blck_1 = 16;
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
const int64_t i02 = i2; for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
const int64_t i03 = i3; const int64_t bir1 = MIN(bir + blck_1, ir1);
for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
const int64_t bne01 = MIN(bi01 + blck_0, ne01);
for (int64_t ir = bir; ir < bir1; ++ir) {
// dst indices
const int64_t i3 = ir/(ne2*ne1);
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
const int64_t i02 = i2;
const int64_t i03 = i3;
//const int64_t i10 = i1;
const int64_t i12 = i2;
const int64_t i13 = i3;
//const int64_t i10 = i1;
const int64_t i12 = i2;
const int64_t i13 = i3;
#if GGML_VEC_MAD_UNROLL > 2 #if GGML_VEC_MAD_UNROLL > 2
const int64_t ne01_unroll = ne01 - (ne01 % GGML_VEC_MAD_UNROLL); const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
for (int64_t i01 = 0; i01 < ne01_unroll; i01 += GGML_VEC_MAD_UNROLL) { for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
const int64_t i11 = i01; const int64_t i11 = i01;
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1); ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
} }
for (int64_t i01 = ne01_unroll; i01 < ne01; ++i01) { for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
const int64_t i11 = i01; const int64_t i11 = i01;
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
ggml_vec_mad_f32(ne0, d, s0, *s1); ggml_vec_mad_f32(ne0, d, s0, *s1);
} }
#else #else
for (int64_t i01 = 0; i01 < ne01; ++i01) { for (int64_t i01 = bi01; i01 < bne01; ++i01) {
const int64_t i11 = i01; const int64_t i11 = i01;
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
ggml_vec_mad_f32(ne0, d, s0, *s1); ggml_vec_mad_f32(ne0, d, s0, *s1);
} }
#endif #endif
}
}
} }