Fix position of map ops cases in ggml_compute_forward

This commit is contained in:
KerfuffleV2 2023-04-14 04:01:50 -06:00
parent 7d695973a5
commit 7d03e6e417

25
ggml.c
View file

@ -3677,7 +3677,6 @@ struct ggml_tensor * ggml_dup_inplace(
return ggml_dup_impl(ctx, a, true); return ggml_dup_impl(ctx, a, true);
} }
// ggml_add // ggml_add
struct ggml_tensor * ggml_add_impl( struct ggml_tensor * ggml_add_impl(
@ -9073,18 +9072,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{ {
ggml_compute_forward_dup(params, tensor->src0, tensor); ggml_compute_forward_dup(params, tensor->src0, tensor);
} break; } break;
case GGML_OP_MAP_UNARY:
{
const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
}
break;
case GGML_OP_MAP_BINARY:
{
const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
}
break;
case GGML_OP_ADD: case GGML_OP_ADD:
{ {
ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor); ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor);
@ -9224,6 +9211,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{ {
ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor); ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
} break; } break;
case GGML_OP_MAP_UNARY:
{
const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
}
break;
case GGML_OP_MAP_BINARY:
{
const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
}
break;
case GGML_OP_NONE: case GGML_OP_NONE:
{ {
// nop // nop