bug fixes for cross entropy loss

ggml_cross_entropy_loss: sums where not correctly added in workload of each thread
ggml_cross_entropy_loss_back: simplify backward process, reducing numerical issues

guard usage of exp f16 lookup in cross entropy by #define GGML_CROSS_ENTROPY_EXP_FP16

cross entropy loss is only used once during training, but it is quite sensitive to numerical errors introduced by exp-f16-lookup.
so exp-f16-lookup for cross entropy loss is disabled by default, trading better gradients for very slightly worse runtime performance.
This commit is contained in:
xaedes 2023-07-02 20:55:54 +02:00
parent 97964a4cc9
commit 2c6985f79e
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

103
ggml.c
View file

@ -123,6 +123,7 @@ typedef void * thread_ret_t;
#define GGML_GELU_FP16 #define GGML_GELU_FP16
#define GGML_GELU_QUICK_FP16 #define GGML_GELU_QUICK_FP16
#define GGML_SILU_FP16 #define GGML_SILU_FP16
// #define GGML_CROSS_ENTROPY_EXP_FP16
#define GGML_SOFT_MAX_UNROLL 4 #define GGML_SOFT_MAX_UNROLL 4
#define GGML_VEC_DOT_UNROLL 2 #define GGML_VEC_DOT_UNROLL 2
@ -11486,6 +11487,7 @@ static void ggml_compute_forward_soft_max_back_f32(
// dx = J * dy // dx = J * dy
// dxk = sum_i(Jki * dyi) // dxk = sum_i(Jki * dyi)
// dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
// dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
// dxk = sum_i(-yk*yi * dyi) + yk*dyk // dxk = sum_i(-yk*yi * dyi) + yk*dyk
// dxk = -yk * sum_i(yi * dyi) + yk*dyk // dxk = -yk * sum_i(yi * dyi) + yk*dyk
// dxk = -yk * dot(y, dy) + yk*dyk // dxk = -yk * dot(y, dy) + yk*dyk
@ -13109,6 +13111,7 @@ static void ggml_compute_forward_flash_attn_f32(
if (SS[j] == -INFINITY) { if (SS[j] == -INFINITY) {
SS[j] = 0.0f; SS[j] = 0.0f;
} else { } else {
// const float val = expf(SS[j] - max);
ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
memcpy(&scvt[j], &s, sizeof(uint16_t)); memcpy(&scvt[j], &s, sizeof(uint16_t));
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
@ -13700,6 +13703,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
if (SR[j] == -INFINITY) { if (SR[j] == -INFINITY) {
SW[j] = 0.0f; SW[j] = 0.0f;
} else { } else {
// const float val = expf(SR[j] - max);
ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max); ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
memcpy(&scvt[j], &s, sizeof(uint16_t)); memcpy(&scvt[j], &s, sizeof(uint16_t));
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
@ -14317,6 +14321,8 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
const int nc = src0->ne[0]; const int nc = src0->ne[0];
const int nr = ggml_nrows(src0); const int nr = ggml_nrows(src0);
GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
if (params->type == GGML_TASK_INIT) { if (params->type == GGML_TASK_INIT) {
if (ith == 0) { if (ith == 0) {
memset(sums, 0, sizeof(float) * (nth + nth * nc)); memset(sums, 0, sizeof(float) * (nth + nth * nc));
@ -14345,7 +14351,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
for (int i1 = ir0; i1 < ir1; i1++) { for (int i1 = ir0; i1 < ir1; i1++) {
float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
float * st = (float *) params->wdata + nth + ith*nc; float * st = ((float *) params->wdata) + nth + ith*nc;
#ifndef NDEBUG #ifndef NDEBUG
for (int i = 0; i < nc; ++i) { for (int i = 0; i < nc; ++i) {
@ -14365,10 +14371,14 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
if (s0[i] == -INFINITY) { if (s0[i] == -INFINITY) {
st[i] = 0.0f; st[i] = 0.0f;
} else { } else {
// const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max); #ifndef GGML_CROSS_ENTROPY_EXP_FP16
const float s = s0[i] - max;
const float val = expf(s);
#else
ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max); ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
memcpy(&scvt, &s, sizeof(scvt)); memcpy(&scvt, &s, sizeof(scvt));
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
#endif
sum += (ggml_float)val; sum += (ggml_float)val;
st[i] = val; st[i] = val;
} }
@ -14384,7 +14394,9 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
ggml_vec_log_f32(nc, st, st); ggml_vec_log_f32(nc, st, st);
ggml_vec_mul_f32(nc, st, st, s1); ggml_vec_mul_f32(nc, st, st, s1);
ggml_vec_sum_f32(nc, sums + ith, st); float st_sum = 0;
ggml_vec_sum_f32(nc, &st_sum, st);
sums[ith] += st_sum;
#ifndef NDEBUG #ifndef NDEBUG
for (int i = 0; i < nc; ++i) { for (int i = 0; i < nc; ++i) {
@ -14434,7 +14446,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
return; return;
} }
const float eps = 1e-9f; const double eps = 1e-9f;
// TODO: handle transposed/permuted matrices // TODO: handle transposed/permuted matrices
const int64_t nc = src0->ne[0]; const int64_t nc = src0->ne[0];
@ -14453,7 +14465,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
float * sm = (float *) params->wdata + ith*nc;
#ifndef NDEBUG #ifndef NDEBUG
for (int i = 0; i < nc; ++i) { for (int i = 0; i < nc; ++i) {
@ -14462,54 +14473,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
assert(!isnan(s1[i])); assert(!isnan(s1[i]));
} }
#endif #endif
// step by step explanation:
{
//float * sums = (float *) params->wdata;
// forward pass with annotated gradients from backward pass
// (built by going in reverse operation order, adding to gradients of current operation args)
// st0 = exp(s0-max(s0)) grad[st0] = grad[st1]*(1.0 - eps)/sum
// from softmax_back: grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
// ggml_vec_scale_f32(nc, st, sum); // st1 = st0*/sum = softmax(s0) grad[st1] = grad[st2]*(1.0 - eps)
// ggml_vec_scale_f32(nc, st, (1.0f - eps)); // st2 = st1*(1.0 - eps) grad[st2] = grad[st3]
// ggml_vec_add1_f32(nc, st, st, eps); // st3 = st2 + eps grad[st3] = grad[st4]/st3
// ggml_vec_log_f32(nc, st, st); // st4 = log(st3) grad[st4] = grad[st5] * s1
// ggml_vec_mul_f32(nc, st, st, s1); // st5 = st4 * s1 grad[st5] = grad[sums[ith]]
// ggml_vec_sum_f32(nc, sums + ith, st); // sums[ith] = st5 grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
// substitute into grad[st1], because we can reuse softmax_back from this point on
// grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
// postorder:
// grad[st1] := softmax(s0)
// grad[st1] := grad[st1]*(1.0 - eps)
// grad[st1] := grad[st1] + eps
// grad[st1] := s1 / grad[st1]
// grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
// src0 gradients by going through softmax_back
// grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
// from softmax_back:
// dxk = yk * (dyk - dot(y, dy))
// dot_y_dy := dot(y, dy)
// dx := dy
// dx := dx - dot_y_dy
// dx := dx * y
// postorder:
// dot_st1_dst1 := dot(st1, grad[st1])
// grad[s0] := grad[st1]
// grad[s0] := grad[s0] - dot_st1_dst1
// grad[s0] := grad[s0] * st1
// prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1]
// sm := softmax(s0)
// grad[s0] := sm*(1.0 - eps)
// grad[s0] := grad[s0] + eps
// grad[s0] := s1 / grad[s0]
// grad[s0] := grad[s0]*(1.0-eps)*-grad[cel]
// dot_st1_dst1 := dot(sm, grad[s0])
// grad[s0] := grad[s0] - dot_st1_dst1
// grad[s0] := grad[s0] * sm
}
// soft_max // soft_max
ggml_float sum = 0.0; ggml_float sum = 0.0;
@ -14520,36 +14483,34 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
uint16_t scvt; uint16_t scvt;
for (int i = 0; i < nc; i++) { for (int i = 0; i < nc; i++) {
if (s0[i] == -INFINITY) { if (s0[i] == -INFINITY) {
sm[i] = 0.0f; ds0[i] = 0.0f;
} else { } else {
// const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max); #ifndef GGML_CROSS_ENTROPY_EXP_FP16
const float s = s0[i] - max;
const float val = expf(s);
#else
ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max); ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
memcpy(&scvt, &s, sizeof(scvt)); memcpy(&scvt, &s, sizeof(scvt));
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
#endif
sum += (ggml_float)val; sum += (ggml_float)val;
sm[i] = val; ds0[i] = val;
} }
} }
assert(sum > 0.0); assert(sum > 0.0);
sum = 1.0/sum; sum = (1.0 - eps)/sum;
} }
float dot_st1_dst1 = 0; // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
ggml_vec_scale_f32(nc, sm, sum); ggml_vec_scale_f32(nc, ds0, sum);
ggml_vec_cpy_f32 (nc, ds0, sm); ggml_vec_add1_f32(nc, ds0, ds0, eps);
ggml_vec_scale_f32(nc, ds0, (1.0f - eps)); ggml_vec_sub_f32(nc, ds0, ds0, s1);
ggml_vec_add1_f32 (nc, ds0, ds0, eps); ggml_vec_scale_f32(nc, ds0, d[0]);
ggml_vec_div_f32 (nc, ds0, s1, ds0);
ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
ggml_vec_dot_f32 (nc, &dot_st1_dst1, sm, ds0);
ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
ggml_vec_mul_f32 (nc, ds0, ds0, sm);
#ifndef NDEBUG #ifndef NDEBUG
for (int i = 0; i < nc; ++i) { for (int i = 0; i < nc; ++i) {
assert(!isnan(sm[i]));
assert(!isinf(sm[i]));
assert(!isnan(ds0[i])); assert(!isnan(ds0[i]));
assert(!isinf(ds0[i])); assert(!isinf(ds0[i]));
} }
@ -16445,10 +16406,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
{ {
n_tasks = n_threads; n_tasks = n_threads;
size_t cur = ggml_type_size(node->type)*node->src[0]->ne[0]*n_tasks;
work_size = MAX(work_size, cur);
} break; } break;
case GGML_OP_NONE: case GGML_OP_NONE:
{ {