From d88dae2980c7d01590eb60afe653cc04fe34ee96 Mon Sep 17 00:00:00 2001
From: xaedes
Date: Thu, 14 Sep 2023 18:39:46 +0200
Subject: [PATCH] block tiling for out-prod inspired by mul-mat

block sizes are empirically optimized

roughly doubles the flops of out-prod
---
 ggml.c | 75 +++++++++++++++++++++++++++++++++-------------------------
 1 file changed, 43 insertions(+), 32 deletions(-)

diff --git a/ggml.c b/ggml.c
index f88aba042..e88c046aa 100644
--- a/ggml.c
+++ b/ggml.c
@@ -11817,49 +11817,60 @@ static void ggml_compute_forward_out_prod_f32(
     const int64_t ir0 = dr*ith;
     const int64_t ir1 = MIN(ir0 + dr, nr);
 
-    for (int64_t ir = ir0; ir < ir1; ++ir) {
-        // dst indices
-        const int64_t i3 = ir/(ne2*ne1);
-        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
-        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+    // block-tiling attempt
+    const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
+    const int64_t blck_1 = 16;
 
-        const int64_t i02 = i2;
-        const int64_t i03 = i3;
+    for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
+        const int64_t bir1 = MIN(bir + blck_1, ir1);
+        for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
+            const int64_t bne01 = MIN(bi01 + blck_0, ne01);
+            for (int64_t ir = bir; ir < bir1; ++ir) {
+                // dst indices
+                const int64_t i3 = ir/(ne2*ne1);
+                const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+                const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                const int64_t i02 = i2;
+                const int64_t i03 = i3;
+
+                //const int64_t i10 = i1;
+                const int64_t i12 = i2;
+                const int64_t i13 = i3;
 
-        //const int64_t i10 = i1;
-        const int64_t i12 = i2;
-        const int64_t i13 = i3;
 #if GGML_VEC_MAD_UNROLL > 2
-        const int64_t ne01_unroll = ne01 - (ne01 % GGML_VEC_MAD_UNROLL);
-        for (int64_t i01 = 0; i01 < ne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
-            const int64_t i11 = i01;
+                const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
+                for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
+                    const int64_t i11 = i01;
 
-            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
-            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-            float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
 
-            ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
-        }
-        for (int64_t i01 = ne01_unroll; i01 < ne01; ++i01) {
-            const int64_t i11 = i01;
+                    ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
+                }
+                for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
+                    const int64_t i11 = i01;
 
-            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
-            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-            float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
 
-            ggml_vec_mad_f32(ne0, d, s0, *s1);
-        }
+                    ggml_vec_mad_f32(ne0, d, s0, *s1);
+                }
 #else
-        for (int64_t i01 = 0; i01 < ne01; ++i01) {
-            const int64_t i11 = i01;
+                for (int64_t i01 = bi01; i01 < bne01; ++i01) {
+                    const int64_t i11 = i01;
 
-            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
-            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-            float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
 
-            ggml_vec_mad_f32(ne0, d, s0, *s1);
-        }
+                    ggml_vec_mad_f32(ne0, d, s0, *s1);
+                }
 #endif
+            }
+        }
     }
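
Note (illustration, not part of the patch): the tiling pattern above can be read in isolation as the following minimal C sketch. The names vec_mad, out_prod_tiled, BLCK_0, BLCK_1 and the toy sizes are made up for this example; the real code works on ggml tensors via ggml_vec_mad_f32 / ggml_vec_mad_f32_unroll and the nb* strides. The idea is to split the destination rows (ir) into chunks of blck_1 and the reduction dimension (i01) into chunks of blck_0, so the src0 rows loaded for one tile are reused across several destination rows while still in cache.

// Standalone sketch of the block-tiling pattern, under the assumptions above.
#include <stdio.h>
#include <stdlib.h>

#define BLCK_0 32   // tile size along the reduction dimension (ne01 in ggml)
#define BLCK_1 16   // tile size along the destination rows (ir in ggml)
#define MIN(a, b) ((a) < (b) ? (a) : (b))

// axpy-style kernel, analogous to ggml_vec_mad_f32: d[i] += s0[i] * v
static void vec_mad(int n, float * d, const float * s0, float v) {
    for (int i = 0; i < n; ++i) {
        d[i] += s0[i] * v;
    }
}

// dst (nr x ne0) accumulates, for every row ir, the sum over k of
// src0 row k (length ne0) scaled by src1[ir][k]. The two outer loops
// walk (BLCK_1 x BLCK_0) tiles; the arithmetic is identical to the
// naive double loop, only the traversal order changes.
static void out_prod_tiled(int nr, int ne0, int ne01,
                           float * dst, const float * src0, const float * src1) {
    for (int bir = 0; bir < nr; bir += BLCK_1) {
        const int bir1 = MIN(bir + BLCK_1, nr);
        for (int bk = 0; bk < ne01; bk += BLCK_0) {
            const int bk1 = MIN(bk + BLCK_0, ne01);
            for (int ir = bir; ir < bir1; ++ir) {
                for (int k = bk; k < bk1; ++k) {
                    vec_mad(ne0, dst + ir*ne0, src0 + k*ne0, src1[ir*ne01 + k]);
                }
            }
        }
    }
}

int main(void) {
    enum { NR = 64, NE0 = 128, NE01 = 96 };
    float * dst  = calloc(NR*NE0,    sizeof(float));
    float * src0 = malloc(NE01*NE0 * sizeof(float));
    float * src1 = malloc(NR*NE01  * sizeof(float));
    for (int i = 0; i < NE01*NE0; ++i) src0[i] = (float)(i % 7);
    for (int i = 0; i < NR*NE01;  ++i) src1[i] = (float)(i % 5);

    out_prod_tiled(NR, NE0, NE01, dst, src0, src1);

    printf("dst[0] = %f\n", dst[0]);
    free(dst); free(src0); free(src1);
    return 0;
}

Because only the loop order differs, the patch does not change any arithmetic; it only restructures the loop nest around the existing ggml_vec_mad_f32 / ggml_vec_mad_f32_unroll calls, which is why the diff touches nothing but ggml_compute_forward_out_prod_f32.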