From 700c782dc14972a34389fffcfc4783dd0db87ac5 Mon Sep 17 00:00:00 2001 From: Kunnis Date: Thu, 9 May 2024 22:18:26 -0500 Subject: [PATCH] Reorg the code. --- ggml.c | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/ggml.c b/ggml.c index 6d368185a..e316c2ba1 100644 --- a/ggml.c +++ b/ggml.c @@ -12056,8 +12056,24 @@ UseGgmlGemm1:; UseGgmlGemm2:; #endif - const int64_t nr0 = ne01; // src0 rows - const int64_t nr1 = ne1*ne12*ne13; // src1 rows +#ifdef GGML_PERF + int chunks_executed = 0; + UNUSED(chunks_executed); +#endif + + //This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) + const int64_t nr0 = ne0; + + //This is the size of the rest of the dimensions of the result + const int64_t nr1 = ne1 * ne2 * ne3; + + // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols + int64_t num_rows_per_vec_dot = vec_dot_num_rows; + // TODO: currently the mmla kernels support only even numbered rows/cols. + // this check can be removed once they are extended to support odd numbered rows/cols too + if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) { + num_rows_per_vec_dot = 1; + } //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); @@ -12084,15 +12100,14 @@ UseGgmlGemm2:; return; } - // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols - int64_t num_rows_per_vec_dot = vec_dot_num_rows; - // TODO: currently the mmla kernels support only even numbered rows/cols. - // this check can be removed once they are extended to support odd numbered rows/cols too - if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) { - num_rows_per_vec_dot = 1; - } - ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); + +#ifdef GGML_PERF + // These numbers are useful when trying to measure how well the threading scheduling works. + //int64_t workSize = (ne01 * ne11 * ne12 * ne13 * ne00) / nchunk0 / nchunk1; + //float time = (ggml_perf_time_us() - t0); + //printf("MUL_MAT = %f ms, [%d, %d, %d, %d] x [%d, %d, %d, %d] = %I64u, %f ops/usec in %d chunks.\n", time / 1000.0, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, workSize, (float)workSize/time, chunks_executed); +#endif } // ggml_compute_forward_mul_mat_id