ggml: add map_ternary_f32()

This commit is contained in:
Kamil Tomsik 2023-05-16 17:43:42 +02:00
parent 2a5ee023ad
commit 8fcf31b4f1
2 changed files with 126 additions and 3 deletions

117
ggml.c
View file

@@ -3465,9 +3465,10 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"MAP_UNARY",
"MAP_BINARY",
"MAP_TERNARY",
};
static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3527,7 +3528,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"f(x,y)",
};
static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -4034,6 +4035,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
/*.grad =*/ NULL,
/*.src0 =*/ NULL,
/*.src1 =*/ NULL,
/*.src2 =*/ NULL,
/*.opt =*/ { NULL },
/*.n_tasks =*/ 0,
/*.perf_runs =*/ 0,
@@ -6421,6 +6423,56 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
}
// ggml_map_ternary
// Build a GGML_OP_MAP_TERNARY graph node that applies a user-supplied f32
// row callback to same-shaped tensors a, b and c. The callback address is
// smuggled through a small I32 tensor stored in opt[0], mirroring the
// existing unary/binary map implementations.
struct ggml_tensor * ggml_map_ternary_impl_f32(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        struct ggml_tensor * c,
        const ggml_ternary_op_f32_t fun,
        bool inplace) {
    GGML_ASSERT(ggml_are_same_shape(a, b));
    GGML_ASSERT(ggml_are_same_shape(b, c));

    // gradients are only tracked for out-of-place nodes
    const bool is_node = !inplace && (a->grad || b->grad || c->grad);

    // stash the function pointer inside a tensor so it travels with the graph
    struct ggml_tensor * fun_holder = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
    *((void (**)(void)) fun_holder->data) = (void (*)(void)) fun;

    struct ggml_tensor * result;
    if (inplace) {
        result = ggml_view_tensor(ctx, a);
    } else {
        result = ggml_dup_tensor(ctx, a);
    }

    result->op     = GGML_OP_MAP_TERNARY;
    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
    result->src0   = a;
    result->src1   = b;
    result->src2   = c;
    result->opt[0] = fun_holder;

    return result;
}
// Out-of-place variant: the result is written to a fresh duplicate of a.
struct ggml_tensor * ggml_map_ternary_f32(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        struct ggml_tensor * c,
        const ggml_ternary_op_f32_t fun) {
    return ggml_map_ternary_impl_f32(ctx, a, b, c, fun, /*inplace=*/false);
}
// In-place variant: the result is a view over a, which is overwritten.
struct ggml_tensor * ggml_map_ternary_inplace_f32(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        struct ggml_tensor * c,
        const ggml_ternary_op_f32_t fun) {
    return ggml_map_ternary_impl_f32(ctx, a, b, c, fun, /*inplace=*/true);
}
////////////////////////////////////////////////////////////////////////////////
void ggml_set_param(
@@ -12628,6 +12680,59 @@ static void ggml_compute_forward_map_binary(
}
}
// ggml_compute_forward_map_ternary
// Forward pass of GGML_OP_MAP_TERNARY for f32 tensors: invokes the user
// callback once per row of the (identically shaped) inputs src0..src2,
// writing each result row into dst.
static void ggml_compute_forward_map_ternary_f32(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
        const struct ggml_tensor * src2,
        struct ggml_tensor * dst,
        const ggml_ternary_op_f32_t fun) {
    assert(params->ith == 0); // op is scheduled single-threaded (n_tasks == 1)
    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src1, src2) && ggml_are_same_shape(src0, dst));

    // nothing to do outside the main compute phase
    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
    }

    const int nrows = ggml_nrows(src0);
    const int ncols = src0->ne[0];

    // rows must consist of contiguous floats
    assert( dst->nb[0] == sizeof(float));
    assert(src0->nb[0] == sizeof(float));
    assert(src1->nb[0] == sizeof(float));
    assert(src2->nb[0] == sizeof(float));

    for (int ir = 0; ir < nrows; ir++) {
        float * dst_row  = (float *) ((char *)  dst->data + ir*( dst->nb[1]));
        float * src0_row = (float *) ((char *) src0->data + ir*(src0->nb[1]));
        float * src1_row = (float *) ((char *) src1->data + ir*(src1->nb[1]));
        float * src2_row = (float *) ((char *) src2->data + ir*(src2->nb[1]));

        fun(ncols, dst_row, src0_row, src1_row, src2_row);
    }
}
// Type dispatcher for GGML_OP_MAP_TERNARY; only F32 tensors are supported.
static void ggml_compute_forward_map_ternary(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
        const struct ggml_tensor * src2,
        struct ggml_tensor * dst,
        const ggml_ternary_op_f32_t fun) {
    if (src0->type == GGML_TYPE_F32) {
        ggml_compute_forward_map_ternary_f32(params, src0, src1, src2, dst, fun);
        return;
    }

    GGML_ASSERT(false); // unsupported tensor type
}
/////////////////////////////////
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -12837,6 +12942,12 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
}
break;
case GGML_OP_MAP_TERNARY:
{
const ggml_ternary_op_f32_t fun = *((ggml_ternary_op_f32_t *)tensor->opt[0]->data);
ggml_compute_forward_map_ternary(params, tensor->src0, tensor->src1, tensor->src2, tensor, fun);
}
break;
case GGML_OP_NONE:
{
// nop
@@ -13517,6 +13628,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
} break;
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_TERNARY:
{
GGML_ASSERT(false); // not supported
} break;
@@ -14062,6 +14174,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} break;
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_TERNARY:
{
node->n_tasks = 1;
} break;

12
ggml.h
View file

@@ -321,6 +321,7 @@ extern "C" {
GGML_OP_MAP_UNARY,
GGML_OP_MAP_BINARY,
GGML_OP_MAP_TERNARY,
GGML_OP_COUNT,
};
@@ -358,6 +359,7 @@ extern "C" {
struct ggml_tensor * grad;
struct ggml_tensor * src0;
struct ggml_tensor * src1;
struct ggml_tensor * src2;
struct ggml_tensor * opt[GGML_MAX_OPT];
// thread scheduling
@@ -372,7 +374,7 @@ extern "C" {
char name[32];
char padding[16];
char padding[8];
};
// computation graph
@@ -931,6 +933,7 @@ extern "C" {
// Mapping operations
typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
typedef void (*ggml_ternary_op_f32_t)(const int, float *, const float *, const float *, const float *);
GGML_API struct ggml_tensor * ggml_map_unary_f32(
struct ggml_context * ctx,
@@ -943,6 +946,13 @@ extern "C" {
struct ggml_tensor * b,
ggml_binary_op_f32_t fun);
GGML_API struct ggml_tensor * ggml_map_ternary_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
const ggml_ternary_op_f32_t fun);
//
// automatic differentiation
//