diff --git a/ggml.c b/ggml.c
index 6296809a4..f88aba042 100644
--- a/ggml.c
+++ b/ggml.c
@@ -134,6 +134,7 @@ typedef void * thread_ret_t;
 
 #define GGML_SOFT_MAX_UNROLL 4
 #define GGML_VEC_DOT_UNROLL 2
+#define GGML_VEC_MAD_UNROLL 32
 
 //
 // logging
@@ -3707,6 +3708,58 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
 #endif
 }
 
+// xs and vs are byte strides of x and v
+inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
+
+    const float * restrict x[GGML_VEC_MAD_UNROLL];
+    const float * restrict v[GGML_VEC_MAD_UNROLL];
+
+    for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
+        x[i] = (const float *) ((const char *) xv + i*xs);
+        v[i] = (const float *) ((const char *) vv + i*vs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
+
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        vx[k] = GGML_F32_VEC_SET1(v[k][0]);
+    }
+
+    GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+            for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+                ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
+                ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
+            }
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        for (int i = np; i < n; ++i) {
+            y[i] += x[k][i]*v[k][0];
+        }
+    }
+#else
+    // scalar
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        for (int i = 0; i < n; ++i) {
+            y[i] += x[k][i]*v[k][0];
+        }
+    }
+#endif
+}
+
 //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
 inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
 #if defined(GGML_USE_ACCELERATE)
@@ -11745,6 +11798,13 @@ static void ggml_compute_forward_out_prod_f32(
         return;
     }
 
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
     // parallelize by last three dimensions
 
     // total rows in dst
@@ -11757,13 +11817,6 @@ static void ggml_compute_forward_out_prod_f32(
     const int64_t ir0 = dr*ith;
     const int64_t ir1 = MIN(ir0 + dr, nr);
 
-    // dst[:,:,:,:] = 0
-    // for i2,i3:
-    //   for i1:
-    //     for i01:
-    //       for i0:
-    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
-
     for (int64_t ir = ir0; ir < ir1; ++ir) {
         // dst indices
         const int64_t i3 = ir/(ne2*ne1);
@@ -11776,7 +11829,27 @@ static void ggml_compute_forward_out_prod_f32(
         //const int64_t i10 = i1;
         const int64_t i12 = i2;
         const int64_t i13 = i3;
+#if GGML_VEC_MAD_UNROLL > 2
+        const int64_t ne01_unroll = ne01 - (ne01 % GGML_VEC_MAD_UNROLL);
+        for (int64_t i01 = 0; i01 < ne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
+            const int64_t i11 = i01;
+            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+            float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+            ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
+        }
+        for (int64_t i01 = ne01_unroll; i01 < ne01; ++i01) {
+            const int64_t i11 = i01;
+
+            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+            float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+            ggml_vec_mad_f32(ne0, d, s0, *s1);
+        }
+#else
         for (int64_t i01 = 0; i01 < ne01; ++i01) {
             const int64_t i11 = i01;
 
@@ -11785,12 +11858,11 @@ static void ggml_compute_forward_out_prod_f32(
             float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
 
             ggml_vec_mad_f32(ne0, d, s0, *s1);
-            // for (int64_t i0 = 0; i0 < ne0; ++i0) {
-            //     d[i0] += s0[i0] * s1[i1];
-            // }
         }
+#endif
     }
+
 
     //int64_t t1 = ggml_perf_time_us();
     //static int64_t acc = 0;
     //acc += t1 - t0;
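For reference, below is a minimal standalone scalar sketch (an illustration, not part of the patch; the names and sizes are made up) of what the new ggml_vec_mad_f32_unroll computes: one call folds GGML_VEC_MAD_UNROLL rows into the same output row, y[i] += sum over k of x[k][i] * v[k][0], which is why the out_prod loop can replace that many individual ggml_vec_mad_f32 calls with a single unrolled call.

#include <stdio.h>

#define UNROLL 4   // the patch uses GGML_VEC_MAD_UNROLL == 32
#define N      8   // vector length, ne0 in the patch

// scalar equivalent of the unrolled routine for contiguous rows
// (mirrors the "// scalar" fallback path in the patch)
static void vec_mad_unroll_ref(int n, float * y, float x[UNROLL][N], const float v[UNROLL]) {
    for (int k = 0; k < UNROLL; ++k) {
        for (int i = 0; i < n; ++i) {
            y[i] += x[k][i]*v[k];
        }
    }
}

int main(void) {
    float x[UNROLL][N], v[UNROLL], y[N] = {0};
    for (int k = 0; k < UNROLL; ++k) {
        v[k] = (float)(k + 1);
        for (int i = 0; i < N; ++i) {
            x[k][i] = 1.0f;
        }
    }
    vec_mad_unroll_ref(N, y, x, v);
    printf("y[0] = %.1f\n", y[0]);   // 1 + 2 + 3 + 4 = 10.0
    return 0;
}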