diff --git a/ggml.c b/ggml.c
index d620cd11f..01b19aa6a 100644
--- a/ggml.c
+++ b/ggml.c
@@ -2712,9 +2712,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
 
     "FLASH_ATTN",
     "FLASH_FF",
+
+    "MAP_UNARY",
+    "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
+static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2757,9 +2760,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "flash_attn(x)",
     "flash_ff(x)",
+
+    "f(x)",
+    "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
+static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3671,6 +3677,92 @@ struct ggml_tensor * ggml_dup_inplace(
     return ggml_dup_impl(ctx, a, true);
 }
 
+
+// ggml_map_binary
+
+struct ggml_tensor * ggml_map_binary_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        void (*const fun)(int, float *, float *, float *),
+        bool inplace) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void **)addr_tensor->data) = fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op = GGML_OP_MAP_BINARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_binary(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        void (*const fun)(int, float *, float *, float *)) {
+    return ggml_map_binary_impl(ctx, a, b, fun, false);
+}
+
+struct ggml_tensor * ggml_map_binary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        void (*const fun)(int, float *, float *, float *)) {
+    return ggml_map_binary_impl(ctx, a, b, fun, true);
+}
+
+// ggml_map_unary
+
+struct ggml_tensor * ggml_map_unary_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        void (*const fun)(int, float *, float *),
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void **)addr_tensor->data) = fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op = GGML_OP_MAP_UNARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_unary(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        void (*const fun)(int, float *, float *)) {
+    return ggml_map_unary_impl(ctx, a, fun, false);
+}
+
+struct ggml_tensor * ggml_map_unary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        void (*const fun)(int, float *, float *)) {
+    return ggml_map_unary_impl(ctx, a, fun, true);
+}
+
+
 // ggml_add
 
 struct ggml_tensor * ggml_add_impl(
@@ -5329,6 +5421,111 @@ static void ggml_compute_forward_dup(
     }
 }
 
+// ggml_compute_forward_map_unary
+
+static void ggml_compute_forward_map_unary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *)) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_unary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *)) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_binary
+
+static void ggml_compute_forward_map_binary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *, float *)) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_binary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *, float *)) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_add
 
 static void ggml_compute_forward_add_f32(
@@ -8877,7 +9074,19 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_dup(params, tensor->src0, tensor);
            } break;
-        case GGML_OP_ADD:
+        case GGML_OP_MAP_UNARY:
+            {
+                void (*const fun)(int, float *, float *) = *((void **)tensor->opt[0]->data);
+                ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_BINARY:
+            {
+                void (*const fun)(int, float *, float *, float *) = *((void **)tensor->opt[0]->data);
+                ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
+            }
+            break;
+        case GGML_OP_ADD:
             {
                 ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor);
             } break;
diff --git a/ggml.h b/ggml.h
index c06c09e06..341711c57 100644
--- a/ggml.h
+++ b/ggml.h
@@ -253,6 +253,9 @@ enum ggml_op {
     GGML_OP_FLASH_ATTN,
     GGML_OP_FLASH_FF,
 
+    GGML_OP_MAP_UNARY,
+    GGML_OP_MAP_BINARY,
+
     GGML_OP_COUNT,
 };
 
@@ -419,6 +422,17 @@ struct ggml_tensor * ggml_dup(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
 
+struct ggml_tensor * ggml_map_unary(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        void (*const fun)(int, float *, float *));
+
+struct ggml_tensor * ggml_map_binary(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        void (*const fun)(int, float *, float *, float *));
+
 struct ggml_tensor * ggml_add(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
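
For reference, a minimal usage sketch of the new API (not part of the patch): the callback is invoked once per row during graph computation and receives the row length (ne[0]), a destination row, and the source row(s). row_tanh and row_sub below are hypothetical user callbacks, and the sketch assumes the ggml_build_forward/ggml_graph_compute flow of this revision.

#include <math.h>
#include <stdio.h>
#include "ggml.h"

// Hypothetical callback for ggml_map_unary: maps one row of n floats.
static void row_tanh(int n, float * dst, float * src) {
    for (int i = 0; i < n; i++) {
        dst[i] = tanhf(src[i]);
    }
}

// Hypothetical callback for ggml_map_binary: combines two same-shaped rows.
static void row_sub(int n, float * dst, float * src0, float * src1) {
    for (int i = 0; i < n; i++) {
        dst[i] = src0[i] - src1[i];
    }
}

int main(void) {
    struct ggml_init_params params = {
        .mem_size   = 16*1024*1024, // arena for tensor data + graph metadata
        .mem_buffer = NULL,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_set_f32(a, 1.0f);
    ggml_set_f32(b, 0.25f);

    // Build tanh(a) - b with the new ops; the callbacks run later,
    // inside ggml_graph_compute, not at graph-construction time.
    struct ggml_tensor * t = ggml_map_unary (ctx, a, row_tanh);
    struct ggml_tensor * r = ggml_map_binary(ctx, t, b, row_sub);

    struct ggml_cgraph gf = ggml_build_forward(r);
    gf.n_threads = 1; // map_binary asserts params->ith == 0
    ggml_graph_compute(ctx, &gf);

    printf("r[0] = %f\n", ggml_get_f32_1d(r, 0)); // tanhf(1.0f) - 0.25f ~= 0.5116

    ggml_free(ctx);
    return 0;
}

Design note: the callback pointer is smuggled through a one-element GGML_TYPE_I32 tensor stored in opt[0], since ggml_tensor has no generic user-data field; this assumes a function pointer fits in sizeof(void *) bytes of tensor data. Also note the diff wires up only forward computation, so graphs containing these ops cannot be differentiated as-is.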