add ggml_cross_entropy_loss with backward pass for faster training

cross entropy loss can also be implemented using softmax and log, but as dedicated operation it is faster and especially avoids unnecessary memory overhead.
This commit is contained in:
xaedes 2023-05-28 21:57:38 +02:00
parent 05cb629c8e
commit 71aaf8dedf
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
2 changed files with 391 additions and 2 deletions

377
ggml.c
View file

@ -3339,9 +3339,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"MAP_UNARY", "MAP_UNARY",
"MAP_BINARY", "MAP_BINARY",
"CROSS_ENTROPY_LOSS",
"CROSS_ENTROPY_LOSS_BACK",
}; };
static_assert(GGML_OP_COUNT == 53, "GGML_OP_COUNT != 53"); static_assert(GGML_OP_COUNT == 55, "GGML_OP_COUNT != 55");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none", "none",
@ -3402,9 +3405,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"f(x)", "f(x)",
"f(x,y)", "f(x,y)",
"cross_entropy_loss(x,y)",
"cross_entropy_loss_back(x,y)",
}; };
static_assert(GGML_OP_COUNT == 53, "GGML_OP_COUNT != 53"); static_assert(GGML_OP_COUNT == 55, "GGML_OP_COUNT != 55");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@ -6347,6 +6353,50 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
return ggml_map_binary_impl_f32(ctx, a, b, fun, true); return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
} }
// ggml_cross_entropy_loss
struct ggml_tensor * ggml_cross_entropy_loss(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
GGML_ASSERT(ggml_are_same_shape(a, b));
bool is_node = false;
if (a->grad || b->grad) {
is_node = true;
}
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
result->op = GGML_OP_CROSS_ENTROPY_LOSS;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = b;
return result;
}
// ggml_cross_entropy_loss_back
struct ggml_tensor * ggml_cross_entropy_loss_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c) {
GGML_ASSERT(ggml_are_same_shape(a, b));
GGML_ASSERT(ggml_is_scalar(c));
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
result->grad = NULL;
result->src0 = a;
result->src1 = b;
result->opt[0] = c;
return result;
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
void ggml_set_param( void ggml_set_param(
@ -12831,6 +12881,287 @@ static void ggml_compute_forward_map_binary(
} }
} }
// ggml_compute_forward_cross_entropy_loss
static void ggml_compute_forward_cross_entropy_loss_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_scalar(dst));
GGML_ASSERT(ggml_are_same_shape(src0, src1));
const int ith = params->ith;
const int nth = params->nth;
float * sums = (float *) params->wdata;
// TODO: handle transposed/permuted matrices
const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);
if (params->type == GGML_TASK_INIT) {
if (ith == 0) {
memset(sums, 0, sizeof(float) * (nth + nth * nc));
}
return;
}
if (params->type == GGML_TASK_FINALIZE) {
if (ith == 0) {
float * dp = (float *) dst->data;
ggml_vec_sum_f32(nth, dp, sums);
dp[0] *= -1.0f;
}
return;
}
const float eps = 1e-9f;
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
float * st = (float *) params->wdata + nth + ith*nc;
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
//printf("p[%d] = %f\n", i, p[i]);
assert(!isnan(s0[i]));
assert(!isnan(s1[i]));
}
#endif
// soft_max
ggml_float sum = 0.0;
{
float max = -INFINITY;
ggml_vec_max_f32(nc, &max, s0);
uint16_t scvt;
for (int i = 0; i < nc; i++) {
if (s0[i] == -INFINITY) {
st[i] = 0.0f;
} else {
// const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
memcpy(&scvt, &s, sizeof(scvt));
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
sum += (ggml_float)val;
st[i] = val;
}
}
assert(sum > 0.0);
sum = 1.0/sum;
}
// avoid log(0) by rescaling from [0..1] to [eps..1]
sum = sum * (1.0f - eps);
ggml_vec_scale_f32(nc, st, sum);
ggml_vec_add1_f32(nc, st, st, eps);
ggml_vec_log_f32(nc, st, st);
ggml_vec_mul_f32(nc, st, st, s1);
ggml_vec_sum_f32(nc, sums + ith, st);
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
assert(!isnan(st[i]));
assert(!isinf(st[i]));
}
#endif
}
}
static void ggml_compute_forward_cross_entropy_loss(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst);
} break;
default:
{
GGML_ASSERT(false);
} break;
}
}
// ggml_compute_forward_cross_entropy_loss_back
static void ggml_compute_forward_cross_entropy_loss_back_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
const struct ggml_tensor * opt0,
struct ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(opt0));
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
const int64_t ith = params->ith;
const int64_t nth = params->nth;
float * sums = (float *) params->wdata;
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
const float eps = 1e-9f;
// TODO: handle transposed/permuted matrices
const int64_t nc = src0->ne[0];
const int64_t nr = ggml_nrows(src0);
// rows per thread
const int64_t dr = (nr + nth - 1)/nth;
// row range for this thread
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
float * d = (float *) opt0->data;
for (int64_t i1 = ir0; i1 < ir1; i1++) {
float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
float * sm = (float *) params->wdata + ith*nc;
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
//printf("p[%d] = %f\n", i, p[i]);
assert(!isnan(s0[i]));
assert(!isnan(s1[i]));
}
#endif
// step by step explanation:
{
// forward pass with annotated gradients from backward pass
// (built by going in reverse operation order, adding to gradients of current operation args)
// st0 = exp(s0-max(s0)) grad[st0] = grad[st1]*(1.0 - eps)/sum
// from softmax_back: grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
// ggml_vec_scale_f32(nc, st, sum); // st1 = st0*/sum = softmax(s0) grad[st1] = grad[st2]*(1.0 - eps)
// ggml_vec_scale_f32(nc, st, (1.0f - eps)); // st2 = st1*(1.0 - eps) grad[st2] = grad[st3]
// ggml_vec_add1_f32(nc, st, st, eps); // st3 = st2 + eps grad[st3] = grad[st4]/st3
// ggml_vec_log_f32(nc, st, st); // st4 = log(st3) grad[st4] = grad[st5] * s1
// ggml_vec_mul_f32(nc, st, st, s1); // st5 = st4 * s1 grad[st5] = grad[sums[ith]]
// ggml_vec_sum_f32(nc, sums + ith, st); // sums[ith] = st5 grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
// substitute into grad[st1], because we can reuse softmax_back from this point on
// grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
// postorder:
// grad[st1] := softmax(s0)
// grad[st1] := grad[st1]*(1.0 - eps)
// grad[st1] := grad[st1] + eps
// grad[st1] := s1 / grad[st1]
// grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
// src0 gradients by going through softmax_back
// grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
// from softmax_back:
// dxk = yk * (dyk - dot(y, dy))
// dot_y_dy := dot(y, dy)
// dx := dy
// dx := dx - dot_y_dy
// dx := dx * y
// postorder:
// dot_st1_dst1 := dot(st1, grad[st1])
// grad[s0] := grad[st1]
// grad[s0] := grad[s0] - dot_st1_dst1
// grad[s0] := grad[s0] * st1
// prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1]
// sm := softmax(s0)
// grad[s0] := sm*(1.0 - eps)
// grad[s0] := grad[s0] + eps
// grad[s0] := s1 / grad[s0]
// grad[s0] := grad[s0]*(1.0-eps)*-grad[cel]
// dot_st1_dst1 := dot(sm, grad[s0])
// grad[s0] := grad[s0] - dot_st1_dst1
// grad[s0] := grad[s0] * sm
}
// soft_max
ggml_float sum = 0.0;
{
float max = -INFINITY;
ggml_vec_max_f32(nc, &max, s0);
uint16_t scvt;
for (int i = 0; i < nc; i++) {
if (s0[i] == -INFINITY) {
sm[i] = 0.0f;
} else {
// const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
memcpy(&scvt, &s, sizeof(scvt));
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
sum += (ggml_float)val;
sm[i] = val;
}
}
assert(sum > 0.0);
sum = 1.0/sum;
}
float dot_st1_dst1 = 0;
ggml_vec_scale_f32(nc, sm, sum);
ggml_vec_cpy_f32 (nc, ds0, sm);
ggml_vec_scale_f32(nc, ds0, (1.0 - eps));
ggml_vec_add1_f32 (nc, ds0, ds0, eps);
ggml_vec_div_f32 (nc, ds0, s1, ds0);
ggml_vec_scale_f32(nc, ds0, -(1.0 - eps)*d[0]);
ggml_vec_dot_f32 (nc, &dot_st1_dst1, sm, ds0);
ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
ggml_vec_mul_f32 (nc, ds0, ds0, sm);
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
assert(!isnan(sm[i]));
assert(!isinf(sm[i]));
assert(!isnan(ds0[i]));
assert(!isinf(ds0[i]));
}
#endif
}
}
static void ggml_compute_forward_cross_entropy_loss_back(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
const struct ggml_tensor * opt0,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst);
} break;
default:
{
GGML_ASSERT(false);
} break;
}
}
///////////////////////////////// /////////////////////////////////
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@ -13052,6 +13383,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun); ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
} }
break; break;
case GGML_OP_CROSS_ENTROPY_LOSS:
{
ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
}
break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
{
ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
}
break;
case GGML_OP_NONE: case GGML_OP_NONE:
{ {
// nop // nop
@ -13677,6 +14018,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{ {
GGML_ASSERT(false); // not supported GGML_ASSERT(false); // not supported
} break; } break;
case GGML_OP_CROSS_ENTROPY_LOSS:
{
if (src0->grad) {
src0->grad = ggml_add_impl(ctx,
src0->grad,
ggml_cross_entropy_loss_back(ctx,
src0,
src1,
tensor->grad),
inplace);
}
} break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
{
GGML_ASSERT(false); // not supported
} break;
case GGML_OP_NONE: case GGML_OP_NONE:
{ {
// nop // nop
@ -14225,6 +14582,22 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
{ {
node->n_tasks = 1; node->n_tasks = 1;
} break; } break;
case GGML_OP_CROSS_ENTROPY_LOSS:
{
node->n_tasks = n_threads;
size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
work_size = MAX(work_size, cur);
} break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
{
node->n_tasks = n_threads;
size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
work_size = MAX(work_size, cur);
} break;
case GGML_OP_NONE: case GGML_OP_NONE:
{ {
node->n_tasks = 1; node->n_tasks = 1;

16
ggml.h
View file

@ -322,6 +322,9 @@ extern "C" {
GGML_OP_MAP_UNARY, GGML_OP_MAP_UNARY,
GGML_OP_MAP_BINARY, GGML_OP_MAP_BINARY,
GGML_OP_CROSS_ENTROPY_LOSS,
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
GGML_OP_COUNT, GGML_OP_COUNT,
}; };
@ -972,6 +975,19 @@ extern "C" {
struct ggml_tensor * b, struct ggml_tensor * b,
ggml_binary_op_f32_t fun); ggml_binary_op_f32_t fun);
// loss function
GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c);
// //
// automatic differentiation // automatic differentiation
// //