Add AVX2 implementation of ggml_compute_forward_rms_norm_f32
This commit is contained in:
parent
4870e455b3
commit
acc36eb0b5
1 changed file with 53 additions and 1 deletion
54
ggml.c
54
ggml.c
|
@ -5651,8 +5651,59 @@ static void ggml_compute_forward_rms_norm_f32(
|
|||
const size_t nb2 = dst->nb[2];
|
||||
const size_t nb3 = dst->nb[3];
|
||||
|
||||
const ggml_float eps = 1e-6f; // TODO: make this a parameter
|
||||
const ggml_float eps = 1e-6; // TODO: make this a parameter
|
||||
|
||||
#if defined(GGML_SIMD)
|
||||
const int np = (ne00 & ~(GGML_F32_STEP - 1));
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
for (int i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
// compute sum of squares of x
|
||||
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
|
||||
GGML_F32_VEC ax[GGML_F32_ARR];
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F32_STEP) {
|
||||
for (int j = 0; j < GGML_F32_ARR; j++) {
|
||||
ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
|
||||
sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ax[j]);
|
||||
}
|
||||
}
|
||||
|
||||
// reduce the partial sums in sum[] to the scalar sumf
|
||||
ggml_float sumf;
|
||||
GGML_F32_VEC_REDUCE(sumf, sum);
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < ne00; ++i) {
|
||||
sumf += x[i] * x[i];
|
||||
}
|
||||
|
||||
// compute scale factor
|
||||
ggml_float meanf = sumf / ne00;
|
||||
ggml_float scalef = 1.0 / sqrt(meanf + eps);
|
||||
GGML_F32_VEC scale = GGML_F32_VEC_SET1(scalef);
|
||||
|
||||
// scale x and copy to y
|
||||
for (int i = 0; i < np; i += GGML_F32_STEP) {
|
||||
for (int j = 0; j < GGML_F32_ARR; j++) {
|
||||
ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
|
||||
ax[j] = GGML_F32_VEC_MUL(ax[j], scale);
|
||||
GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ax[j]);
|
||||
}
|
||||
}
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < ne00; ++i) {
|
||||
y[i] = x[i] * scalef;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
// TODO: optimize
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
|
@ -5679,6 +5730,7 @@ static void ggml_compute_forward_rms_norm_f32(
|
|||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_rms_norm(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue