bug fixes for cross entropy loss

ggml_cross_entropy_loss: sums were not correctly added in the workload of each thread
ggml_cross_entropy_loss_back: simplify the backward process, reducing numerical issues
guard usage of the exp f16 lookup in cross entropy by #define GGML_CROSS_ENTROPY_EXP_FP16

cross entropy loss is only used once during training, but it is quite sensitive to the numerical errors introduced by the exp-f16-lookup. so the exp-f16-lookup for cross entropy loss is disabled by default, trading better gradients for very slightly worse runtime performance.
This commit is contained in:
parent 97964a4cc9
commit 2c6985f79e

1 changed file with 30 additions and 73 deletions

ggml.c (+30 -73)
@@ -123,6 +123,7 @@ typedef void * thread_ret_t;
 #define GGML_GELU_FP16
 #define GGML_GELU_QUICK_FP16
 #define GGML_SILU_FP16
+// #define GGML_CROSS_ENTROPY_EXP_FP16

 #define GGML_SOFT_MAX_UNROLL 4
 #define GGML_VEC_DOT_UNROLL 2
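For reference, the lookup this new flag guards replaces expf() with a 65536-entry fp16 table indexed by the bit pattern of the fp16-rounded argument, so both the input and the output are rounded to fp16 (about 3 decimal digits). A minimal sketch of the technique as the hunks below use it (exp_f16_lookup is an illustrative name, not a ggml function; memcpy comes from <string.h>, already included in ggml.c):

    // the table is built once at init time, roughly as:
    //   table_exp_f16[i] = GGML_FP32_TO_FP16(expf(GGML_FP16_TO_FP32(i)));
    static inline float exp_f16_lookup(float x) {
        ggml_fp16_t s = GGML_FP32_TO_FP16(x); // round the argument to fp16 ...
        uint16_t scvt;
        memcpy(&scvt, &s, sizeof(scvt));      // ... and reuse its bits as the table index
        return GGML_FP16_TO_FP32(table_exp_f16[scvt]);
    }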
@@ -11486,6 +11487,7 @@ static void ggml_compute_forward_soft_max_back_f32(
     // dx = J * dy
     // dxk = sum_i(Jki * dyi)
     // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
+    // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
     // dxk = sum_i(-yk*yi * dyi) + yk*dyk
     // dxk = -yk * sum_i(yi * dyi) + yk*dyk
     // dxk = -yk * dot(y, dy) + yk*dyk
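The added comment fills in the one algebra step the derivation skipped (expanding the diagonal term before it cancels). As a plain-C reference for the final identity, a sketch (soft_max_back_ref is an illustrative name, not the actual ggml kernel):

    // dx = softmax_back(y, dy), where y = softmax(x) and dy is the upstream gradient:
    // dxk = yk * (dyk - dot(y, dy))
    static void soft_max_back_ref(int n, float * dx, const float * y, const float * dy) {
        float dot_y_dy = 0.0f;
        for (int i = 0; i < n; ++i) {
            dot_y_dy += y[i]*dy[i];          // dot(y, dy)
        }
        for (int k = 0; k < n; ++k) {
            dx[k] = y[k]*(dy[k] - dot_y_dy); // = -yk*dot(y, dy) + yk*dyk
        }
    }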
@@ -13109,6 +13111,7 @@ static void ggml_compute_forward_flash_attn_f32(
                         if (SS[j] == -INFINITY) {
                             SS[j] = 0.0f;
                         } else {
+                            // const float val = expf(SS[j] - max);
                             ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
                             memcpy(&scvt[j], &s, sizeof(uint16_t));
                             const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
@@ -13700,6 +13703,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
                         if (SR[j] == -INFINITY) {
                             SW[j] = 0.0f;
                         } else {
+                            // const float val = expf(SR[j] - max);
                             ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
                             memcpy(&scvt[j], &s, sizeof(uint16_t));
                             const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
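Both flash-attention hunks only document the exact expf() call the table lookup approximates; the fast path itself stays unconditional there. The `- max` in the argument is the usual softmax stabilization:

    // exp(x_i - m) / sum_j exp(x_j - m) == exp(x_i) / sum_j exp(x_j) for any m;
    // choosing m = max keeps every argument <= 0, so each result is <= 1 and expf cannot overflow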
@@ -14317,6 +14321,8 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
     const int nc = src0->ne[0];
     const int nr = ggml_nrows(src0);

+    GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
+
     if (params->type == GGML_TASK_INIT) {
         if (ith == 0) {
             memset(sums, 0, sizeof(float) * (nth + nth * nc));
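The asserted size spells out the wdata layout this kernel assumes: nth per-thread partial sums up front, then one nc-wide scratch row per thread. As offsets (the same pointers appear in the next hunk):

    // params->wdata, in floats:
    //   [ sums[0..nth-1] | st row of thread 0 (nc floats) | ... | st row of thread nth-1 ]
    float * sums = (float *) params->wdata;                   // nth partial sums
    float * st   = ((float *) params->wdata) + nth + ith*nc;  // this thread's scratch row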
@@ -14345,7 +14351,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
             for (int i1 = ir0; i1 < ir1; i1++) {
                 float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
                 float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
-                float * st = (float *) params->wdata + nth + ith*nc;
+                float * st = ((float *) params->wdata) + nth + ith*nc;

 #ifndef NDEBUG
                 for (int i = 0; i < nc; ++i) {
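The st change is purely cosmetic: a cast binds tighter than +, so both spellings already offset in units of float and compute the same address; the parentheses only make that explicit.

    float * a = (float *) params->wdata + nth + ith*nc;   // cast applies before the '+'
    float * b = ((float *) params->wdata) + nth + ith*nc; // same address as 'a'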
@@ -14365,10 +14371,14 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
                     if (s0[i] == -INFINITY) {
                         st[i] = 0.0f;
                     } else {
-                        // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+#ifndef GGML_CROSS_ENTROPY_EXP_FP16
+                        const float s = s0[i] - max;
+                        const float val = expf(s);
+#else
                         ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
                         memcpy(&scvt, &s, sizeof(scvt));
                         const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+#endif
                         sum += (ggml_float)val;
                         st[i] = val;
                     }
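For orientation, the per-row quantity accumulated into sums[ith] by this function (combining the exp above with the log/mul/sum in the next hunk, and the eps smoothing documented in the comments removed further down) is:

    \sum_i s1_i \,\log\bigl(\epsilon + (1 - \epsilon)\,\mathrm{softmax}(s0)_i\bigr)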
@@ -14384,7 +14394,9 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
             ggml_vec_log_f32(nc, st, st);
             ggml_vec_mul_f32(nc, st, st, s1);

-            ggml_vec_sum_f32(nc, sums + ith, st);
+            float st_sum = 0;
+            ggml_vec_sum_f32(nc, &st_sum, st);
+            sums[ith] += st_sum;

 #ifndef NDEBUG
             for (int i = 0; i < nc; ++i) {
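This hunk is the headline fix from the commit message: ggml_vec_sum_f32 stores its result, it does not accumulate, so passing sums + ith inside the row loop clobbered the thread's partial sum on every iteration and only the last row of each thread's workload survived. A sketch of its semantics, close to the scalar path in ggml.c:

    inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
        ggml_float sum = 0.0;
        for (int i = 0; i < n; ++i) {
            sum += (ggml_float)x[i];
        }
        *s = sum; // plain store: a second call overwrites, never adds
    }

Summing into the local st_sum and then doing sums[ith] += st_sum accumulates correctly across all rows in [ir0, ir1).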
@@ -14434,7 +14446,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         return;
     }

-    const float eps = 1e-9f;
+    const double eps = 1e-9f;

     // TODO: handle transposed/permuted matrices
     const int64_t nc = src0->ne[0];
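Minor: the widened eps keeps its f suffix, so the value still rounds through float before the implicit conversion to double. The difference is far below 1e-9 itself, but plain 1e-9 would be the exact double literal:

    const double a = 1e-9f; // (double)(float)1e-9: rounds through float first
    const double b = 1e-9;  // exact double literal; a and b differ in the low bits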
@@ -14453,7 +14465,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
         float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
         float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
-        float * sm = (float *) params->wdata + ith*nc;

 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
@@ -14462,54 +14473,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
             assert(!isnan(s1[i]));
         }
 #endif
-        // step by step explanation:
-        {
-            //float * sums = (float *) params->wdata;
-
-            // forward pass with annotated gradients from backward pass
-            // (built by going in reverse operation order, adding to gradients of current operation args)
-            // st0 = exp(s0-max(s0))                                                       grad[st0] = grad[st1]*(1.0 - eps)/sum
-            // from softmax_back: grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
-            // ggml_vec_scale_f32(nc, st, sum);           // st1 = st0*/sum = softmax(s0)  grad[st1] = grad[st2]*(1.0 - eps)
-            // ggml_vec_scale_f32(nc, st, (1.0f - eps));  // st2 = st1*(1.0 - eps)         grad[st2] = grad[st3]
-            // ggml_vec_add1_f32(nc, st, st, eps);        // st3 = st2 + eps               grad[st3] = grad[st4]/st3
-            // ggml_vec_log_f32(nc, st, st);              // st4 = log(st3)                grad[st4] = grad[st5] * s1
-            // ggml_vec_mul_f32(nc, st, st, s1);          // st5 = st4 * s1                grad[st5] = grad[sums[ith]]
-            // ggml_vec_sum_f32(nc, sums + ith, st);      // sums[ith] = st5               grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
-
-            // substitute into grad[st1], because we can reuse softmax_back from this point on
-            // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
-            // postorder:
-            // grad[st1] := softmax(s0)
-            // grad[st1] := grad[st1]*(1.0 - eps)
-            // grad[st1] := grad[st1] + eps
-            // grad[st1] := s1 / grad[st1]
-            // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
-
-            // src0 gradients by going through softmax_back
-            // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
-            // from softmax_back:
-            // dxk = yk * (dyk - dot(y, dy))
-            // dot_y_dy := dot(y, dy)
-            // dx := dy
-            // dx := dx - dot_y_dy
-            // dx := dx * y
-            // postorder:
-            // dot_st1_dst1 := dot(st1, grad[st1])
-            // grad[s0] := grad[st1]
-            // grad[s0] := grad[s0] - dot_st1_dst1
-            // grad[s0] := grad[s0] * st1
-
-            // prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1]
-            // sm := softmax(s0)
-            // grad[s0] := sm*(1.0 - eps)
-            // grad[s0] := grad[s0] + eps
-            // grad[s0] := s1 / grad[s0]
-            // grad[s0] := grad[s0]*(1.0-eps)*-grad[cel]
-            // dot_st1_dst1 := dot(sm, grad[s0])
-            // grad[s0] := grad[s0] - dot_st1_dst1
-            // grad[s0] := grad[s0] * sm
-        }

         // soft_max
         ggml_float sum = 0.0;
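The removed walkthrough is superseded by a standard identity. Ignoring the eps smoothing and assuming each row of the labels s1 sums to 1, with y = softmax(s0):

    \frac{\partial}{\partial s0_k}\Bigl(-\sum_i s1_i \log y_i\Bigr)
        = -\sum_i s1_i\,(\delta_{ik} - y_k)
        = y_k \sum_i s1_i - s1_k
        = y_k - s1_k

which is exactly the (softmax(src0) - src1) form the new code computes below; eps only keeps the smoothed softmax away from 0, and the incoming gradient d[0] carries the remaining scaling.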
@@ -14520,36 +14483,34 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
             uint16_t scvt;
             for (int i = 0; i < nc; i++) {
                 if (s0[i] == -INFINITY) {
-                    sm[i] = 0.0f;
+                    ds0[i] = 0.0f;
                 } else {
-                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+#ifndef GGML_CROSS_ENTROPY_EXP_FP16
+                    const float s = s0[i] - max;
+                    const float val = expf(s);
+#else
                     ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
                     memcpy(&scvt, &s, sizeof(scvt));
                     const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+#endif
                     sum += (ggml_float)val;
-                    sm[i] = val;
+                    ds0[i] = val;
                 }
             }

             assert(sum > 0.0);
-            sum = 1.0/sum;
+            sum = (1.0 - eps)/sum;
         }

-        float dot_st1_dst1 = 0;
-        ggml_vec_scale_f32(nc, sm, sum);
-        ggml_vec_cpy_f32  (nc, ds0, sm);
-        ggml_vec_scale_f32(nc, ds0, (1.0f - eps));
+        // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
+        ggml_vec_scale_f32(nc, ds0, sum);
         ggml_vec_add1_f32(nc, ds0, ds0, eps);
-        ggml_vec_div_f32  (nc, ds0, s1, ds0);
-        ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
-        ggml_vec_dot_f32  (nc, &dot_st1_dst1, sm, ds0);
-        ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
-        ggml_vec_mul_f32  (nc, ds0, ds0, sm);
+        ggml_vec_sub_f32(nc, ds0, ds0, s1);
+        ggml_vec_scale_f32(nc, ds0, d[0]);

 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
-            assert(!isnan(sm[i]));
-            assert(!isinf(sm[i]));
             assert(!isnan(ds0[i]));
             assert(!isinf(ds0[i]));
         }
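Putting the new path together, a per-row reference in plain C (a sketch under the same eps smoothing; cross_entropy_loss_back_row_ref is an illustrative name, not a ggml function):

    #include <math.h>

    // ds0 = (eps + (1 - eps)*softmax(s0) - s1) * d0, for one row of length nc
    static void cross_entropy_loss_back_row_ref(int nc, float * ds0,
            const float * s0, const float * s1, float d0, float eps) {
        float max = -INFINITY;
        for (int i = 0; i < nc; ++i) {
            if (s0[i] > max) max = s0[i];
        }
        double sum = 0.0;
        for (int i = 0; i < nc; ++i) {
            ds0[i] = (s0[i] == -INFINITY) ? 0.0f : expf(s0[i] - max);
            sum += ds0[i];
        }
        const float scale = (float)((1.0 - eps)/sum);
        for (int i = 0; i < nc; ++i) {
            ds0[i] = (ds0[i]*scale + eps - s1[i]) * d0; // (softmax - s1) * grad
        }
    }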
@@ -16445,10 +16406,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
                     {
                         n_tasks = n_threads;
-
-                        size_t cur = ggml_type_size(node->type)*node->src[0]->ne[0]*n_tasks;
-
-                        work_size = MAX(work_size, cur);
                     } break;
                 case GGML_OP_NONE:
                     {
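With the separate sm scratch row gone (the backward kernel now builds the softmax directly in ds0), GGML_OP_CROSS_ENTROPY_LOSS_BACK no longer needs any work buffer, so the plan keeps only the task count:

    case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
        {
            n_tasks = n_threads; // still parallel over rows; no wdata required
        } break;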