add ggml_log operation necessary for cross entropy loss

parent 8cf04fec9d
commit 65d9f7349d

2 changed files with 112 additions and 2 deletions
ggml.c (105 lines changed)
@@ -3734,6 +3734,7 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
 inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s);   }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
 inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
+inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = log(x[i]);   }
 inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
 inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
 inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
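One stylistic wrinkle worth flagging: the neighbouring vec ops call the single-precision sqrtf/fabsf, while the new ggml_vec_log_f32 calls the double-precision log() and relies on implicit conversion back to float. The element-wise behaviour, checked in isolation (a minimal standalone harness, not part of the commit):

#include <math.h>
#include <stdio.h>

// Mirrors ggml_vec_log_f32: element-wise natural log over n floats.
static void vec_log_f32(const int n, float * y, const float * x) {
    for (int i = 0; i < n; ++i) y[i] = log(x[i]); // implicit double -> float
}

int main(void) {
    const float x[3] = { 1.0f, 2.718281828f, 0.5f };
    float y[3];
    vec_log_f32(3, y, x);
    // Prints roughly 0, 1, -0.693147. Inputs <= 0 yield -inf or NaN, so a
    // cross entropy caller must keep its probabilities strictly positive.
    for (int i = 0; i < 3; ++i) printf("%f\n", y[i]);
    return 0;
}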
@@ -3965,6 +3966,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "DIV",
     "SQR",
     "SQRT",
+    "LOG",
     "SUM",
     "SUM_ROWS",
     "MEAN",
@@ -4009,7 +4011,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 47, "GGML_OP_COUNT != 47");
+static_assert(GGML_OP_COUNT == 48, "GGML_OP_COUNT != 48");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -4023,6 +4025,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "x/y",
     "x^2",
     "√x",
+    "log(x)",
     "Σx",
     "Σx_k",
     "Σx/n",
@@ -4067,7 +4070,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 47, "GGML_OP_COUNT != 47");
+static_assert(GGML_OP_COUNT == 48, "GGML_OP_COUNT != 48");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -5303,6 +5306,41 @@ struct ggml_tensor * ggml_sqrt_inplace(
     return ggml_sqrt_impl(ctx, a, true);
 }
 
+// ggml_log
+
+struct ggml_tensor * ggml_log_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_LOG;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_log(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_log_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_log_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_log_impl(ctx, a, true);
+}
+
 // ggml_sum
 
 struct ggml_tensor * ggml_sum(
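With the constructors in place, ggml_log composes like any other unary op at graph-build time. A hedged sketch of the motivating use case from the commit message, cross entropy assembled out of existing ops (the helper and tensor names are illustrative, not part of this commit; assumes probs and labels are same-shape F32 tensors with probs strictly positive):

// loss = -Σ labels * log(probs)
struct ggml_tensor * cross_entropy_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor  * probs,
        struct ggml_tensor  * labels) {
    struct ggml_tensor * logp = ggml_log(ctx, probs);        // log(p)
    struct ggml_tensor * tlp  = ggml_mul(ctx, labels, logp); // t * log(p)
    struct ggml_tensor * s    = ggml_sum(ctx, tlp);          // Σ t*log(p), scalar
    return ggml_neg(ctx, s);                                 // -Σ t*log(p)
}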
@@ -8572,6 +8610,49 @@ static void ggml_compute_forward_sqrt(
     }
 }
 
+// ggml_compute_forward_log
+
+static void ggml_compute_forward_log_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_log_f32(nc,
+                (float *) ((char *)  dst->data + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_log(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_log_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_sum
 
 static void ggml_compute_forward_sum_f32(
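The forward kernel runs single-threaded (params->ith must be 0) and walks the tensor row by row: ggml_nrows gives the row count, nc the row length, and nb[1] the byte stride between rows, so only each individual row has to be contiguous f32. The same stride arithmetic in isolation (a standalone illustration, not commit code):

#include <stdio.h>
#include <stddef.h>

int main(void) {
    // A 2x4 "tensor": n rows of nc floats, rows nb1 bytes apart.
    float data[2][4] = { {1.f, 2.f, 3.f, 4.f}, {5.f, 6.f, 7.f, 8.f} };
    const int    n   = 2;                // ggml_nrows(src0)
    const int    nc  = 4;                // src0->ne[0]
    const size_t nb1 = 4*sizeof(float);  // src0->nb[1]

    for (int i = 0; i < n; i++) {
        const float * row = (const float *) ((const char *) data + i*nb1);
        printf("row %d: first %f, last %f\n", i, row[0], row[nc-1]);
    }
    return 0;
}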
@@ -12871,6 +12952,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_sqrt(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_LOG:
+            {
+                ggml_compute_forward_log(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_SUM:
             {
                 ggml_compute_forward_sum(params, tensor->src0, tensor);
@@ -13168,6 +13253,21 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             inplace);
                 }
             } break;
+        case GGML_OP_LOG:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_div(ctx,
+                                tensor->grad,
+                                src0),
+                            inplace);
+                }
+                if (src1->grad) {
+                    // not supported
+                }
+            } break;
         case GGML_OP_SUM:
             {
                 if (src0->grad) {
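The backward rule is the chain rule applied to the element-wise natural log: since d/dx ln x = 1/x, the incoming gradient is divided element-wise by the input, which is exactly the ggml_div(ctx, tensor->grad, src0) term accumulated into src0->grad:

y_i = \ln x_i
\quad\Longrightarrow\quad
\frac{\partial L}{\partial x_i}
  = \frac{\partial L}{\partial y_i} \cdot \frac{1}{x_i}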
@@ -14006,6 +14106,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 case GGML_OP_DIV:
                 case GGML_OP_SQR:
                 case GGML_OP_SQRT:
+                case GGML_OP_LOG:
                 case GGML_OP_SUM:
                 case GGML_OP_SUM_ROWS:
                 case GGML_OP_MEAN:
ggml.h (9 lines changed)
@@ -259,6 +259,7 @@ extern "C" {
         GGML_OP_DIV,
         GGML_OP_SQR,
         GGML_OP_SQRT,
+        GGML_OP_LOG,
         GGML_OP_SUM,
         GGML_OP_SUM_ROWS,
         GGML_OP_MEAN,
@@ -535,6 +536,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_log(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_log_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // return scalar
     GGML_API struct ggml_tensor * ggml_sum(
             struct ggml_context * ctx,
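As declared, the API follows the established unary-op pattern: ggml_log allocates a new result tensor, while ggml_log_inplace returns a view that overwrites its input. A hedged end-to-end fragment (memory size and tensor shape are illustrative; error handling omitted):

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_f32(x, 2.0f);

    struct ggml_tensor * y = ggml_log(ctx, x); // fresh tensor; x untouched

    struct ggml_cgraph gf = ggml_build_forward(y);
    ggml_graph_compute(ctx, &gf);

    // y now holds log(2) in every element; ggml_log_inplace(ctx, x) would
    // instead write the result into a view of x.
    ggml_free(ctx);
    return 0;
}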