feat: implemented sigmoid function (ggml/806)
* added sigmoid function * implemented metal kernel for sigmoid * implemented cuda kernel for sigmoid * added sigmoid unary op and incremented count
This commit is contained in:
parent
ef0d5e3ec9
commit
f5ef34e428
7 changed files with 136 additions and 1 deletions
73
ggml.c
73
ggml.c
|
@ -1949,6 +1949,7 @@ inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) {
|
|||
inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
|
||||
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
|
||||
inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
|
||||
inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
|
||||
// TODO: optimize performance
|
||||
inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
|
||||
inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
|
||||
|
@ -2329,6 +2330,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
|||
"TANH",
|
||||
"ELU",
|
||||
"RELU",
|
||||
"SIGMOID",
|
||||
"GELU",
|
||||
"GELU_QUICK",
|
||||
"SILU",
|
||||
|
@ -2336,7 +2338,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
|||
"HARDSIGMOID",
|
||||
};
|
||||
|
||||
static_assert(GGML_UNARY_OP_COUNT == 12, "GGML_UNARY_OP_COUNT != 12");
|
||||
static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
|
||||
|
||||
|
||||
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
||||
|
@ -4561,6 +4563,20 @@ struct ggml_tensor * ggml_leaky_relu(
|
|||
return result;
|
||||
}
|
||||
|
||||
// ggml_sigmoid
|
||||
|
||||
struct ggml_tensor * ggml_sigmoid(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_sigmoid_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);
|
||||
}
|
||||
|
||||
// ggml_gelu
|
||||
|
||||
struct ggml_tensor * ggml_gelu(
|
||||
|
@ -10852,6 +10868,52 @@ static void ggml_compute_forward_relu(
|
|||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_sigmoid
|
||||
|
||||
static void ggml_compute_forward_sigmoid_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(params->ith == 0);
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int n = ggml_nrows(src0);
|
||||
const int nc = src0->ne[0];
|
||||
|
||||
assert(dst->nb[0] == sizeof(float));
|
||||
assert(src0->nb[0] == sizeof(float));
|
||||
|
||||
for (int i = 0; i < n; i++) {
|
||||
ggml_vec_sigmoid_f32(nc,
|
||||
(float *) ((char *) dst->data + i*( dst->nb[1])),
|
||||
(float *) ((char *) src0->data + i*(src0->nb[1])));
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_sigmoid(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_sigmoid_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ASSERT(false);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_gelu
|
||||
|
||||
static void ggml_compute_forward_gelu_f32(
|
||||
|
@ -16617,6 +16679,10 @@ static void ggml_compute_forward_unary(
|
|||
{
|
||||
ggml_compute_forward_relu(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
{
|
||||
ggml_compute_forward_sigmoid(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_GELU:
|
||||
{
|
||||
ggml_compute_forward_gelu(params, dst);
|
||||
|
@ -18601,6 +18667,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
{
|
||||
GGML_ASSERT(false); // TODO: not implemented
|
||||
} break;
|
||||
case GGML_UNARY_OP_GELU:
|
||||
{
|
||||
GGML_ASSERT(false); // TODO: not implemented
|
||||
|
@ -19130,6 +19200,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
|
|||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads
|
||||
case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue