From 807c8252ced16419e623c27299055abb2dea9fe7 Mon Sep 17 00:00:00 2001 From: Kunnis Date: Thu, 9 May 2024 23:50:37 -0500 Subject: [PATCH] Add in the re-chunking code. --- ggml.c | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/ggml.c b/ggml.c index f04a11b69..b6f90300a 100644 --- a/ggml.c +++ b/ggml.c @@ -12062,9 +12062,6 @@ UseGgmlGemm1:; UseGgmlGemm2:; #endif - int chunk_size = 16; - UNUSED(chunk_size); - #ifdef GGML_PERF int chunks_executed = 0; UNUSED(chunks_executed); @@ -12084,12 +12081,32 @@ UseGgmlGemm2:; num_rows_per_vec_dot = 1; } + //Now select a reasonable chunk size. + int chunk_size = 16; + + //We need to step up the size if it's small + if (nr0 == 1 || nr1 == 1) + chunk_size = 64; + + // distribute the work across the inner or outer loop based on which one is larger + //The number of chunks in the 0/1 dim. + //CEIL(nr0/chunk_size) + int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size; + int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size; + //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); - // distribute the thread work across the inner or outer loop based on which one is larger + //If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread. + if (nchunk0 * nchunk1 < nth * 4) + { + // distribute the thread work across the inner or outer loop based on which one is larger + nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + } - const int64_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows - const int64_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + //The number of elements in each chunk + const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; + const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; //The first chunk comes from our thread_id, the rest will get auto-assigned. int current_chunk = ith; @@ -12099,9 +12116,6 @@ UseGgmlGemm2:; const int64_t ith0 = current_chunk % nchunk0; const int64_t ith1 = current_chunk / nchunk0; - const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; - const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; - const int64_t ir0_start = dr0 * ith0; const int64_t ir0_end = MIN(ir0_start + dr0, nr0);