Add block tiling for out-prod, inspired by mul-mat.
The block sizes are empirically optimized; this roughly doubles the FLOPS of out-prod.
This commit is contained in:
parent
0971fee710
commit
d88dae2980
1 changed file with 43 additions and 32 deletions
21
ggml.c
21
ggml.c
|
@ -11817,7 +11817,15 @@ static void ggml_compute_forward_out_prod_f32(
|
||||||
const int64_t ir0 = dr*ith;
|
const int64_t ir0 = dr*ith;
|
||||||
const int64_t ir1 = MIN(ir0 + dr, nr);
|
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
// block-tiling attempt
|
||||||
|
const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
|
||||||
|
const int64_t blck_1 = 16;
|
||||||
|
|
||||||
|
for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
|
||||||
|
const int64_t bir1 = MIN(bir + blck_1, ir1);
|
||||||
|
for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
|
||||||
|
const int64_t bne01 = MIN(bi01 + blck_0, ne01);
|
||||||
|
for (int64_t ir = bir; ir < bir1; ++ir) {
|
||||||
// dst indices
|
// dst indices
|
||||||
const int64_t i3 = ir/(ne2*ne1);
|
const int64_t i3 = ir/(ne2*ne1);
|
||||||
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
|
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
|
||||||
|
@ -11829,9 +11837,10 @@ static void ggml_compute_forward_out_prod_f32(
|
||||||
//const int64_t i10 = i1;
|
//const int64_t i10 = i1;
|
||||||
const int64_t i12 = i2;
|
const int64_t i12 = i2;
|
||||||
const int64_t i13 = i3;
|
const int64_t i13 = i3;
|
||||||
|
|
||||||
#if GGML_VEC_MAD_UNROLL > 2
|
#if GGML_VEC_MAD_UNROLL > 2
|
||||||
const int64_t ne01_unroll = ne01 - (ne01 % GGML_VEC_MAD_UNROLL);
|
const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
|
||||||
for (int64_t i01 = 0; i01 < ne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
|
for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
|
||||||
const int64_t i11 = i01;
|
const int64_t i11 = i01;
|
||||||
|
|
||||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||||
|
@ -11840,7 +11849,7 @@ static void ggml_compute_forward_out_prod_f32(
|
||||||
|
|
||||||
ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
|
ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
|
||||||
}
|
}
|
||||||
for (int64_t i01 = ne01_unroll; i01 < ne01; ++i01) {
|
for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
|
||||||
const int64_t i11 = i01;
|
const int64_t i11 = i01;
|
||||||
|
|
||||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||||
|
@ -11850,7 +11859,7 @@ static void ggml_compute_forward_out_prod_f32(
|
||||||
ggml_vec_mad_f32(ne0, d, s0, *s1);
|
ggml_vec_mad_f32(ne0, d, s0, *s1);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
for (int64_t i01 = bi01; i01 < bne01; ++i01) {
|
||||||
const int64_t i11 = i01;
|
const int64_t i11 = i01;
|
||||||
|
|
||||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||||
|
@ -11861,6 +11870,8 @@ static void ggml_compute_forward_out_prod_f32(
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//int64_t t1 = ggml_perf_time_us();
|
//int64_t t1 = ggml_perf_time_us();
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue