From 1b6fd5470bdb9b041ed4f8629a8bf09c418a72a2 Mon Sep 17 00:00:00 2001 From: Will Beddow Date: Fri, 7 Apr 2023 00:49:19 -0400 Subject: [PATCH] WIP on f16 --- ggml.c | 45 ++++++++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/ggml.c b/ggml.c index 3d2a825cb..57fce983c 100644 --- a/ggml.c +++ b/ggml.c @@ -4936,16 +4936,32 @@ static void ggml_compute_forward_dup_f16( int64_t i12 = 0; int64_t i13 = 0; + const int thread_num = params->ith; + const int total_threads = params->nth; + const int64_t regions = ne03 * ne02 * ne01 * ne00; + + const int64_t regions_per_thread = (regions + total_threads - 1) / total_threads; + + const int64_t thread_start_region = regions_per_thread * thread_num; + const int64_t thread_stop_region = MIN(thread_start_region + regions_per_thread, regions); + + int region_index = 0; + if (dst->type == GGML_TYPE_F16) { for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + if (region_index > thread_stop_region) break; - memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t)); + if (region_index++ >= thread_start_region) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t)); + } + + // Regardless, we have to keep the dst counters updated if (++i10 == ne00) { i10 = 0; if (++i11 == ne01) { @@ -4967,11 +4983,17 @@ static void ggml_compute_forward_dup_f16( for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr); + if (region_index > thread_stop_region) break; + if (region_index++ >= thread_start_region) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr); + } + + // Regardless, we have to keep the dst counters updated if (++i10 == ne00) { i10 = 0; if (++i11 == ne01) { @@ -5029,7 +5051,6 @@ static void ggml_compute_forward_dup_f32( int64_t i12 = 0; int64_t i13 = 0; - const int thread_num = params->ith; const int total_threads = params->nth; const int64_t regions = ne03 * ne02 * ne01 * ne00; @@ -5046,11 +5067,9 @@ static void ggml_compute_forward_dup_f32( for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { for (int64_t i00 = 0; i00 < ne00; i00++) { - if (region_index > thread_stop_region) break; - - if (region_index >= thread_start_region) { + if (region_index++ >= thread_start_region) { const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); @@ -5079,10 +5098,9 @@ static void ggml_compute_forward_dup_f32( for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { for (int64_t i00 = 0; i00 < ne00; i00++) { - if (region_index > thread_stop_region) break; - if (region_index >= thread_start_region) { + if (region_index++ >= thread_start_region) { const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); @@ -7302,7 +7320,6 @@ static void ggml_compute_forward_rope_f32( // row index used to determine which thread to use int ir = 0; - int i = 0; for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) { @@ -7312,7 +7329,6 @@ static void ggml_compute_forward_rope_f32( if (ir > ir1) break; for (int i0 = 0; i0 < n_dims; i0 += 2) { - ++i; const float theta = powf(10000.0, ((float)-i0)/n_dims); const float cos_theta = cosf(p*theta); @@ -7330,7 +7346,6 @@ static void ggml_compute_forward_rope_f32( } } } - } static void ggml_compute_forward_rope_f16(