From a968553c6f1d633f2cd3264220bd731c79e0a9fa Mon Sep 17 00:00:00 2001 From: Kunnis Date: Wed, 8 May 2024 23:43:43 -0500 Subject: [PATCH] Renaming and moving a bunch of variables around. --- ggml.c | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/ggml.c b/ggml.c index 06d2c1bf6..89c6bc332 100644 --- a/ggml.c +++ b/ggml.c @@ -11983,44 +11983,44 @@ UseGgmlGemm2:; const int64_t dr0 = (nr0 + nth0 - 1)/nth0; const int64_t dr1 = (nr1 + nth1 - 1)/nth1; - const int64_t ir010 = dr0*ith0; - const int64_t ir011 = MIN(ir010 + dr0, nr0); + const int64_t ir0_start = dr0*ith0; + const int64_t ir0_end = MIN(ir0_start + dr0, nr0); - const int64_t ir110 = dr1*ith1; - const int64_t ir111 = MIN(ir110 + dr1, nr1); + const int64_t ir1_start = dr1*ith1; + const int64_t ir1_end = MIN(ir1_start + dr1, nr1); - //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111); + //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end); // threads with no work simply yield (not sure if it helps) - if (ir010 >= ir011 || ir110 >= ir111) { + if (ir0_start >= ir0_end || ir1_start >= ir1_end) { sched_yield(); return; } - assert(ne12 % ne02 == 0); - assert(ne13 % ne03 == 0); + // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols + int64_t num_rows_per_vec_dot = vec_dot_num_rows; + // TODO: currently the mmla kernels support only even numbered rows/cols. + // this check can be removed once they are extended to support odd numbered rows/cols too + if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) { + num_rows_per_vec_dot = 1; + } + + assert(ne12% ne02 == 0); + assert(ne13% ne03 == 0); // block-tiling attempt const int64_t blck_0 = 16; const int64_t blck_1 = 16; - // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols - int64_t nrc = vec_dot_num_rows; - // TODO: currently the mmla kernels support only even numbered rows/cols. - // this check can be removed once they are extended to support odd numbered rows/cols too - if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) { - nrc = 1; - } - const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11; // attempt to reduce false-sharing (does not seem to make a difference) // 16 * 2, accounting for mmla kernels float tmp[32]; - for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { - for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { - for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) { + for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { + for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { + for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) { const int64_t i13 = (ir1/(ne12*ne1)); const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1; const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1); @@ -12045,16 +12045,16 @@ UseGgmlGemm2:; : (i11*nb11 + i12*nb12 + i13*nb13)); float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)); - //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) { // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); //} - for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) { - vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc); + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { + vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot >1 ? 16 : 0), src0_row + ir0*nb01, (num_rows_per_vec_dot >1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot >1 ? src1_col_stride : 0), num_rows_per_vec_dot); } - for (int cn = 0; cn < nrc; ++cn) { - memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); + for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { + memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float)); } } }