diff --git a/ggml.c b/ggml.c index 57fce983c..0d88699c6 100644 --- a/ggml.c +++ b/ggml.c @@ -4938,12 +4938,6 @@ static void ggml_compute_forward_dup_f16( const int thread_num = params->ith; const int total_threads = params->nth; - const int64_t regions = ne03 * ne02 * ne01 * ne00; - - const int64_t regions_per_thread = (regions + total_threads - 1) / total_threads; - - const int64_t thread_start_region = regions_per_thread * thread_num; - const int64_t thread_stop_region = MIN(thread_start_region + regions_per_thread, regions); int region_index = 0; @@ -4952,9 +4946,9 @@ static void ggml_compute_forward_dup_f16( for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { for (int64_t i00 = 0; i00 < ne00; i00++) { - if (region_index > thread_stop_region) break; - if (region_index++ >= thread_start_region) { + // Interleave execution so that in a 4 thread run thread 0 copies regions 0,4,8, ... + if ((region_index++ % total_threads) == thread_num) { const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); @@ -4984,9 +4978,8 @@ static void ggml_compute_forward_dup_f16( for (int64_t i01 = 0; i01 < ne01; i01++) { for (int64_t i00 = 0; i00 < ne00; i00++) { - if (region_index > thread_stop_region) break; - - if (region_index++ >= thread_start_region) { + // Interleave execution so that in a 4 thread run thread 0 copies regions 0,4,8, ... + if ((region_index++ % total_threads) == thread_num) { const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); @@ -5053,12 +5046,6 @@ static void ggml_compute_forward_dup_f32( const int thread_num = params->ith; const int total_threads = params->nth; - const int64_t regions = ne03 * ne02 * ne01 * ne00; - - const int64_t regions_per_thread = (regions + total_threads - 1) / total_threads; - - const int64_t thread_start_region = regions_per_thread * thread_num; - const int64_t thread_stop_region = MIN(thread_start_region + regions_per_thread, regions); int region_index = 0; @@ -5067,9 +5054,9 @@ static void ggml_compute_forward_dup_f32( for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { for (int64_t i00 = 0; i00 < ne00; i00++) { - if (region_index > thread_stop_region) break; - if (region_index++ >= thread_start_region) { + // Interleave execution so that in a 4 thread run thread 0 copies regions 0,4,8, ... + if ((region_index++ % total_threads) == thread_num) { const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); @@ -5098,9 +5085,7 @@ static void ggml_compute_forward_dup_f32( for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { for (int64_t i00 = 0; i00 < ne00; i00++) { - if (region_index > thread_stop_region) break; - - if (region_index++ >= thread_start_region) { + if ((region_index++ % total_threads) == thread_num) { const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);