From 4f95478ea3d4ec370efada288c6e1345778d425d Mon Sep 17 00:00:00 2001 From: Kunnis Date: Thu, 9 May 2024 00:07:40 -0500 Subject: [PATCH] Moving some variable definitions into the chunk function. --- ggml.c | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/ggml.c b/ggml.c index 23ecc7535..3b5e0719a 100644 --- a/ggml.c +++ b/ggml.c @@ -11779,10 +11779,7 @@ static void ggml_compute_forward_mul_mat_one_chunk( const int64_t ir1_end, const bool src1_cont, - const int64_t r2, - const int64_t r3, enum ggml_type vec_dot_type, - const void* wdata, const size_t row_size, const ggml_vec_dot_t vec_dot ) { @@ -11794,6 +11791,14 @@ static void ggml_compute_forward_mul_mat_one_chunk( const enum ggml_type type = src0->type; + const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + + // broadcast factors + const int64_t r2 = ne12 / ne02; + const int64_t r3 = ne13 / ne03; + + //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end); + assert(ne12 % ne02 == 0); assert(ne13 % ne03 == 0); @@ -11891,8 +11896,8 @@ static void ggml_compute_forward_mul_mat( GGML_ASSERT(nb2 <= nb3); // broadcast factors - const int64_t r2 = ne12/ne02; - const int64_t r3 = ne13/ne03; + const int64_t r2 = ne12 / ne02; + const int64_t r3 = ne13 / ne03; // nb01 >= nb00 - src0 is not transposed // compute by src0 rows @@ -12070,8 +12075,6 @@ UseGgmlGemm2:; const int64_t ir1_start = dr1*ith1; const int64_t ir1_end = MIN(ir1_start + dr1, nr1); - //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end); - // threads with no work simply yield (not sure if it helps) if (ir0_start >= ir0_end || ir1_start >= ir1_end) { sched_yield(); @@ -12086,7 +12089,7 @@ UseGgmlGemm2:; num_rows_per_vec_dot = 1; } - ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end, src1_cont, r2, r3, vec_dot_type, wdata, row_size, vec_dot); + ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end, src1_cont, vec_dot_type, row_size, vec_dot); } // ggml_compute_forward_mul_mat_id