diff --git a/ggml.c b/ggml.c index 0e4b1466c..dea0e9f0b 100644 --- a/ggml.c +++ b/ggml.c @@ -5651,8 +5651,59 @@ static void ggml_compute_forward_rms_norm_f32( const size_t nb2 = dst->nb[2]; const size_t nb3 = dst->nb[3]; - const ggml_float eps = 1e-6f; // TODO: make this a parameter + const ggml_float eps = 1e-6; // TODO: make this a parameter +#if defined(GGML_SIMD) + const int np = (ne00 & ~(GGML_F32_STEP - 1)); + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + // compute sum of squares of x + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + GGML_F32_VEC ax[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ax[j]); + } + } + + // reduce sum0..sum3 to sum0 + ggml_float sumf; + GGML_F32_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < ne00; ++i) { + sumf += x[i] * x[i]; + } + + // compute scale factor + ggml_float meanf = sumf / ne00; + ggml_float scalef = 1.0 / sqrt(meanf + eps); + GGML_F32_VEC scale = GGML_F32_VEC_SET1(scalef); + + // scale x and copy to y + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ax[j] = GGML_F32_VEC_MUL(ax[j], scale); + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ax[j]); + } + } + + // leftovers + for (int i = np; i < ne00; ++i) { + y[i] = x[i] * scalef; + } + } + } + } +#else // TODO: optimize for (int i03 = 0; i03 < ne03; i03++) { for (int i02 = 0; i02 < ne02; i02++) { @@ -5679,6 +5730,7 @@ static void ggml_compute_forward_rms_norm_f32( } } } +#endif } static void ggml_compute_forward_rms_norm(