ggml : fix soft_max with bias on CPU

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-04-19 18:03:56 +03:00
parent 3badef1fe1
commit 871fcb6e10
No known key found for this signature in database
GPG key ID: BF970631944C16B7
2 changed files with 9 additions and 3 deletions

4
ggml.c
View file

@ -12410,7 +12410,7 @@ static void ggml_compute_forward_soft_max_f32(
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
// when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
float * pos = src2 ? (float *) src2->data : src0->data; ggml_fp16_t * pos = src2 ? (ggml_fp16_t *) src2->data : src0->data;
for (int i1 = ir0; i1 < ir1; i1++) { for (int i1 = ir0; i1 < ir1; i1++) {
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
@ -12433,7 +12433,7 @@ static void ggml_compute_forward_soft_max_f32(
const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
for (int i = 0; i < nc; i++) { for (int i = 0; i < nc; i++) {
wp[i] = wp[i] + slope*pos[i]; wp[i] = wp[i] + slope*ggml_fp16_to_fp32(pos[i]);
} }
} }

View file

@ -1103,6 +1103,12 @@ struct test_soft_max : public test_case {
return VARS_TO_STR5(type, ne, mask, scale, max_bias); return VARS_TO_STR5(type, ne, mask, scale, max_bias);
} }
// the 1024 test with bias occasionally fails:
// SOFT_MAX(type=f32,ne=[1024,16,1,1],mask=1,scale=1.000000,max_bias=8.000000): [SOFT_MAX] NMSE = 0.000000103 > 0.000000100 FAIL
virtual double max_nmse_err() override {
return 1e-6;
}
test_soft_max(ggml_type type = GGML_TYPE_F32, test_soft_max(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 10, 10}, std::array<int64_t, 4> ne = {10, 10, 10, 10},
bool mask = false, bool mask = false,