diff --git a/ggml.c b/ggml.c
index 50e3d1775..3b8b4f4b3 100644
--- a/ggml.c
+++ b/ggml.c
@@ -7646,28 +7646,52 @@ static void ggml_compute_forward_add1_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int n = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
 
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+    const size_t nb12 = src1->nb[2];
+    const size_t nb13 = src1->nb[3];
 
     const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
 
-    for (int j = ith; j < n; j += nth) {
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
 #ifdef GGML_USE_ACCELERATE
         vDSP_vadd(
-                (float *) ((char *) src0->data + j*nb01), 1,
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
                 (float *) ((char *) src1->data), 0,
-                (float *) ((char *) dst->data + j*nb1), 1, nc);
+                (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
+                ne0);
 #else
-        ggml_vec_add1_f32(nc,
-                (float *) ((char *) dst->data + j*nb1),
-                (float *) ((char *) src0->data + j*nb01),
+        ggml_vec_add1_f32(ne0,
+                (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
                 v);
 #endif
     }
@@ -7691,14 +7715,20 @@ static void ggml_compute_forward_add1_f16_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int n = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
 
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
 
     const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -7707,10 +7737,22 @@ static void ggml_compute_forward_add1_f16_f32(
     GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
 
-    for (int j = ith; j < n; j += nth) {
-        ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
-        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
-        for (int i = 0; i < nc; i++) {
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        for (int i = 0; i < ne0; i++) {
             dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
         }
     }
@@ -7734,14 +7776,20 @@ static void ggml_compute_forward_add1_f16_f16(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int n = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
 
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
 
     const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F16);
@@ -7750,10 +7798,22 @@ static void ggml_compute_forward_add1_f16_f16(
     GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
 
-    for (int j = ith; j < n; j += nth) {
-        ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
-        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
-        for (int i = 0; i < nc; i++) {
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        for (int i = 0; i < ne0; i++) {
             dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
         }
     }
@@ -7774,38 +7834,23 @@ static void ggml_compute_forward_add1_q_f32(
     // scalar to add
     const float v = *(float *) src1->data;
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    //const int64_t ne10 = src1->ne[0];
-    //const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    //const int64_t ne0 = dst->ne[0];
-    //const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3];
-
-    const int nb00 = src0->nb[0];
-    const int nb01 = src0->nb[1];
-    const int nb02 = src0->nb[2];
-    const int nb03 = src0->nb[3];
-
-    const int nb0 = dst->nb[0];
-    const int nb1 = dst->nb[1];
-    const int nb2 = dst->nb[2];
-    const int nb3 = dst->nb[3];
-
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne03 == ne13);
-    GGML_ASSERT(ne2 == ne12);
-    GGML_ASSERT(ne3 == ne13);
+    const int nr = ggml_nrows(src0);
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
     const enum ggml_type type = src0->type;
     dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
@@ -7823,9 +7868,6 @@ static void ggml_compute_forward_add1_q_f32(
     GGML_ASSERT(dst->type == src0->type);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
-    // total rows in src0
-    const int nr = ne01*ne02*ne03;
-
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -7833,30 +7875,25 @@ static void ggml_compute_forward_add1_q_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
 
     for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 indices
-        const int i03 = ir/(ne02*ne01);
-        const int i02 = (ir - i03*ne02*ne01)/ne01;
-        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
 
-        // dst is same shape as src0 => same indices
-        const int i3 = i03;
-        const int i2 = i02;
-        const int i1 = i01;
+        void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
+        void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb0 ));
 
-        void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-        void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb0));
-
-        assert(ne00 % 32 == 0);
+        assert(ne0 % 32 == 0);
 
         // unquantize row from src0 to temp buffer
-        dequantize_row_q(src0_row, wdata, ne00);
+        dequantize_row_q(src0_row, wdata, ne0);
         // add src1
-        ggml_vec_acc1_f32(ne00, wdata, v);
+        ggml_vec_acc1_f32(ne0, wdata, v);
         // quantize row to dst
-        quantize_row_q(wdata, dst_row, ne00);
+        quantize_row_q(wdata, dst_row, ne0);
     }
 }
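For reference, the scheduling pattern the patch applies to all four add1 kernels can be shown in isolation: each thread takes a contiguous block of rows (ceiling division of nr by nth) and unflattens its row index ir into (i1, i2, i3). The sketch below is not part of the patch; the function name `add1_rows_sketch` and the printf driver are made up for illustration, and MIN is redefined locally rather than using ggml's macro.

```c
// Minimal standalone sketch of the row partitioning and 4D index
// decomposition used in the hunks above. Hypothetical names, not ggml API.
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

static void add1_rows_sketch(int ith, int nth, int ne1, int ne2, int ne3) {
    const int nr = ne1*ne2*ne3;          // total rows (what ggml_nrows returns)

    // rows per thread (ceiling division), same formula as the patch
    const int dr = (nr + nth - 1)/nth;

    // contiguous row range owned by thread ith
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int ir = ir0; ir < ir1; ++ir) {
        // unflatten ir into per-dimension row indices, as in the patch
        const int i3 = ir/(ne2*ne1);
        const int i2 = (ir - i3*ne2*ne1)/ne1;
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
        printf("thread %d: row %d -> (i1=%d, i2=%d, i3=%d)\n", ith, ir, i1, i2, i3);
    }
}

int main(void) {
    // e.g. a 4x3x2 batch of rows split across 4 threads
    for (int ith = 0; ith < 4; ++ith) {
        add1_rows_sketch(ith, 4, 4, 3, 2);
    }
    return 0;
}
```

Compared with the previous `for (int j = ith; j < n; j += nth)` interleaving, each thread now walks a contiguous range of rows, and clamping `ir1` with MIN handles the case where nr does not divide evenly by nth.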