add ggml_log operation necessary for cross entropy loss

This commit is contained in:
xaedes 2023-05-06 17:35:13 +02:00
parent 8cf04fec9d
commit 65d9f7349d
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
2 changed files with 112 additions and 2 deletions

105
ggml.c
View file

@ -3734,6 +3734,7 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); }
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = log(x[i]); }
inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
@ -3965,6 +3966,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"DIV",
"SQR",
"SQRT",
"LOG",
"SUM",
"SUM_ROWS",
"MEAN",
@ -4009,7 +4011,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"MAP_BINARY",
};
static_assert(GGML_OP_COUNT == 47, "GGML_OP_COUNT != 47");
static_assert(GGML_OP_COUNT == 48, "GGML_OP_COUNT != 48");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -4023,6 +4025,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"x/y",
"x^2",
"√x",
"log(x)",
"Σx",
"Σx_k",
"Σx/n",
@ -4067,7 +4070,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"f(x,y)",
};
static_assert(GGML_OP_COUNT == 47, "GGML_OP_COUNT != 47");
static_assert(GGML_OP_COUNT == 48, "GGML_OP_COUNT != 48");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@ -5303,6 +5306,41 @@ struct ggml_tensor * ggml_sqrt_inplace(
return ggml_sqrt_impl(ctx, a, true);
}
// ggml_log
struct ggml_tensor * ggml_log_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
bool inplace) {
bool is_node = false;
if (!inplace && (a->grad)) {
is_node = true;
}
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
result->op = GGML_OP_LOG;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = NULL;
return result;
}
struct ggml_tensor * ggml_log(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_log_impl(ctx, a, false);
}
struct ggml_tensor * ggml_log_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_log_impl(ctx, a, true);
}
// ggml_sum
struct ggml_tensor * ggml_sum(
@ -8572,6 +8610,49 @@ static void ggml_compute_forward_sqrt(
}
}
// ggml_compute_forward_log
static void ggml_compute_forward_log_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
GGML_ASSERT(params->ith == 0);
GGML_ASSERT(ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
const int n = ggml_nrows(src0);
const int nc = src0->ne[0];
GGML_ASSERT( dst->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_vec_log_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_compute_forward_log(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_log_f32(params, src0, dst);
} break;
default:
{
GGML_ASSERT(false);
} break;
}
}
// ggml_compute_forward_sum
static void ggml_compute_forward_sum_f32(
@ -12871,6 +12952,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_sqrt(params, tensor->src0, tensor);
} break;
case GGML_OP_LOG:
{
ggml_compute_forward_log(params, tensor->src0, tensor);
} break;
case GGML_OP_SUM:
{
ggml_compute_forward_sum(params, tensor->src0, tensor);
@ -13168,6 +13253,21 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
inplace);
}
} break;
case GGML_OP_LOG:
{
if (src0->grad) {
src0->grad =
ggml_add_impl(ctx,
src0->grad,
ggml_div(ctx,
tensor->grad,
src0),
inplace);
}
if (src1->grad) {
// not supported
}
} break;
case GGML_OP_SUM:
{
if (src0->grad) {
@ -14006,6 +14106,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
case GGML_OP_DIV:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_LOG:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:

9
ggml.h
View file

@ -259,6 +259,7 @@ extern "C" {
GGML_OP_DIV,
GGML_OP_SQR,
GGML_OP_SQRT,
GGML_OP_LOG,
GGML_OP_SUM,
GGML_OP_SUM_ROWS,
GGML_OP_MEAN,
@ -535,6 +536,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_log(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_log_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// return scalar
GGML_API struct ggml_tensor * ggml_sum(
struct ggml_context * ctx,