Add AVX2 implementation of ggml_compute_forward_rms_norm_f32

This commit is contained in:
Slaren 2023-03-24 01:10:46 +01:00
parent 4870e455b3
commit acc36eb0b5

54
ggml.c
View file

@ -5651,8 +5651,59 @@ static void ggml_compute_forward_rms_norm_f32(
const size_t nb2 = dst->nb[2];
const size_t nb3 = dst->nb[3];
const ggml_float eps = 1e-6f; // TODO: make this a parameter
const ggml_float eps = 1e-6; // TODO: make this a parameter
#if defined(GGML_SIMD)
const int np = (ne00 & ~(GGML_F32_STEP - 1));
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
// compute sum of squares of x
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
GGML_F32_VEC ax[GGML_F32_ARR];
for (int i = 0; i < np; i += GGML_F32_STEP) {
for (int j = 0; j < GGML_F32_ARR; j++) {
ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ax[j]);
}
}
// reduce the partial sums held in sum[] into the scalar sumf
ggml_float sumf;
GGML_F32_VEC_REDUCE(sumf, sum);
// leftovers
for (int i = np; i < ne00; ++i) {
sumf += x[i] * x[i];
}
// compute scale factor
ggml_float meanf = sumf / ne00;
ggml_float scalef = 1.0 / sqrt(meanf + eps);
GGML_F32_VEC scale = GGML_F32_VEC_SET1(scalef);
// scale x and copy to y
for (int i = 0; i < np; i += GGML_F32_STEP) {
for (int j = 0; j < GGML_F32_ARR; j++) {
ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
ax[j] = GGML_F32_VEC_MUL(ax[j], scale);
GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ax[j]);
}
}
// leftovers
for (int i = np; i < ne00; ++i) {
y[i] = x[i] * scalef;
}
}
}
}
#else
// TODO: optimize
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
@ -5679,6 +5730,7 @@ static void ggml_compute_forward_rms_norm_f32(
}
}
}
#endif
}
static void ggml_compute_forward_rms_norm(