block tiling for out-prod inspired by mul-mat
block sizes are empirically optimized roughly doubles the flops of out-prod
This commit is contained in:
parent
0971fee710
commit
d88dae2980
1 changed files with 43 additions and 32 deletions
75
ggml.c
75
ggml.c
|
@ -11817,49 +11817,60 @@ static void ggml_compute_forward_out_prod_f32(
|
|||
const int64_t ir0 = dr*ith;
|
||||
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
// dst indices
|
||||
const int64_t i3 = ir/(ne2*ne1);
|
||||
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
|
||||
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
||||
// block-tiling attempt
|
||||
const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
|
||||
const int64_t blck_1 = 16;
|
||||
|
||||
const int64_t i02 = i2;
|
||||
const int64_t i03 = i3;
|
||||
for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
|
||||
const int64_t bir1 = MIN(bir + blck_1, ir1);
|
||||
for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
|
||||
const int64_t bne01 = MIN(bi01 + blck_0, ne01);
|
||||
for (int64_t ir = bir; ir < bir1; ++ir) {
|
||||
// dst indices
|
||||
const int64_t i3 = ir/(ne2*ne1);
|
||||
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
|
||||
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
||||
|
||||
const int64_t i02 = i2;
|
||||
const int64_t i03 = i3;
|
||||
|
||||
//const int64_t i10 = i1;
|
||||
const int64_t i12 = i2;
|
||||
const int64_t i13 = i3;
|
||||
|
||||
//const int64_t i10 = i1;
|
||||
const int64_t i12 = i2;
|
||||
const int64_t i13 = i3;
|
||||
#if GGML_VEC_MAD_UNROLL > 2
|
||||
const int64_t ne01_unroll = ne01 - (ne01 % GGML_VEC_MAD_UNROLL);
|
||||
for (int64_t i01 = 0; i01 < ne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
|
||||
const int64_t i11 = i01;
|
||||
const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
|
||||
for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
|
||||
const int64_t i11 = i01;
|
||||
|
||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
|
||||
ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
|
||||
}
|
||||
for (int64_t i01 = ne01_unroll; i01 < ne01; ++i01) {
|
||||
const int64_t i11 = i01;
|
||||
ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
|
||||
}
|
||||
for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
|
||||
const int64_t i11 = i01;
|
||||
|
||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
|
||||
ggml_vec_mad_f32(ne0, d, s0, *s1);
|
||||
}
|
||||
ggml_vec_mad_f32(ne0, d, s0, *s1);
|
||||
}
|
||||
#else
|
||||
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
||||
const int64_t i11 = i01;
|
||||
for (int64_t i01 = bi01; i01 < bne01; ++i01) {
|
||||
const int64_t i11 = i01;
|
||||
|
||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
|
||||
ggml_vec_mad_f32(ne0, d, s0, *s1);
|
||||
}
|
||||
ggml_vec_mad_f32(ne0, d, s0, *s1);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue