refactoring: remove internal functions, use op table directly

hongruichen 2024-07-27 13:43:07 +08:00
parent e0c9b34016
commit 8ab1f15fe3


@@ -304,9 +304,8 @@ bool qnn_unary_op_impl(ggml_backend_qnn_context *ctx, ggml_tensor *src, ggml_ten
 
     return succeed;
 }
 
-ggml_qnn_unary_op_array_t ggml_qnn_unary_op_array() {
-    static constexpr const ggml_qnn_unary_op_t kQnnOpsTable[] = {
+constexpr const ggml_qnn_unary_op_t kQnnUnaryOpsTable[] = {
     nullptr, // GGML_OP_NONE
     nullptr, // GGML_OP_DUP
     nullptr, // GGML_OP_ADD
@@ -406,13 +405,10 @@ ggml_qnn_unary_op_array_t ggml_qnn_unary_op_array() {
     nullptr, // GGML_UNARY_OP_HARDSIGMOID
 };
 
-    static_assert(sizeof(kQnnOpsTable) / sizeof(kQnnOpsTable[0]) == (GGML_OP_COUNT + GGML_UNARY_OP_COUNT),
-                  "GGML_OP_COUNT does not match the size of the kQnnOpsTable table");
-    return kQnnOpsTable;
-}
+static_assert(sizeof(kQnnUnaryOpsTable) / sizeof(kQnnUnaryOpsTable[0]) == (GGML_OP_COUNT + GGML_UNARY_OP_COUNT),
+              "GGML_OP_COUNT does not match the size of the kQnnUnaryOpsTable table");
 
-ggml_qnn_binary_op_array_t ggml_qnn_binary_op_array() {
-    static constexpr const ggml_qnn_binary_op_t kQnnOpsTable[] = {
+static constexpr const ggml_qnn_binary_op_t kQnnBinaryOpsTable[] = {
     nullptr, // GGML_OP_NONE
     nullptr, // GGML_OP_DUP
     qnn_binary_op_impl<GGML_OP_ADD>, // GGML_OP_ADD
@@ -497,10 +493,8 @@ ggml_qnn_binary_op_array_t ggml_qnn_binary_op_array() {
     nullptr, // GGML_OP_CROSS_ENTROPY_LOSS_BACK
 };
 
-    static_assert(sizeof(kQnnOpsTable) / sizeof(kQnnOpsTable[0]) == GGML_OP_COUNT,
-                  "GGML_OP_COUNT does not match the size of the ops table");
-    return kQnnOpsTable;
-}
+static_assert(sizeof(kQnnBinaryOpsTable) / sizeof(kQnnBinaryOpsTable[0]) == GGML_OP_COUNT,
+              "GGML_OP_COUNT does not match the size of the kQnnBinaryOpsTable table");
 
 } // namespace
 
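
The pattern the three hunks above converge on, as a minimal standalone sketch (the enum, handler type, and names below are simplified stand-ins, not the real ggml/QNN definitions): each dispatch table becomes a namespace-scope constexpr array of function pointers with one slot per op, where a nullptr slot marks an unsupported op and a static_assert pins the array length to the op count at compile time, replacing the old accessor functions that returned a function-local table.

enum my_op { MY_OP_NONE, MY_OP_ADD, MY_OP_MUL, MY_OP_COUNT };

// one handler per op; returns false on failure
using my_op_fn = bool (*)(int src0, int src1, int *dst);

static bool add_impl(int src0, int src1, int *dst) {
    *dst = src0 + src1;
    return true;
}

// namespace-scope constexpr table: no accessor wrapper, initialized at
// compile time, indexed directly by the op id
static constexpr const my_op_fn kOpsTable[] = {
    nullptr,  // MY_OP_NONE: nothing to execute
    add_impl, // MY_OP_ADD
    nullptr,  // MY_OP_MUL: nullptr means "unsupported"
};

// keeps the table in sync when new ops are added to the enum
static_assert(sizeof(kOpsTable) / sizeof(kOpsTable[0]) == MY_OP_COUNT,
              "MY_OP_COUNT does not match the size of the kOpsTable table");
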
@@ -508,7 +502,7 @@ namespace qnn {
 
 bool ggml_qnn_supports_op(const ggml_tensor *op) {
     if (op->op == GGML_OP_UNARY) {
-        if (!ggml_qnn_unary_op_array()[kGgmlUnaryOpStart + ggml_get_unary_op(op)]) {
+        if (!kQnnUnaryOpsTable[kGgmlUnaryOpStart + ggml_get_unary_op(op)]) {
             QNN_LOG_DEBUG("unsupported unary op %d", ggml_get_unary_op(op));
             return false;
         }
@@ -518,7 +512,7 @@ bool ggml_qnn_supports_op(const ggml_tensor *op) {
             return false;
         }
     } else if (op->op != GGML_OP_NONE) {
-        if (!ggml_qnn_unary_op_array()[op->op] && !ggml_qnn_binary_op_array()[op->op]) {
+        if (!kQnnUnaryOpsTable[op->op] && !kQnnBinaryOpsTable[op->op]) {
             QNN_LOG_DEBUG("unsupported op %d", op->op);
             return false;
         }
@@ -555,12 +549,12 @@ bool ggml_qnn_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor)
         unary_op_idx = kGgmlUnaryOpStart + ggml_get_unary_op(tensor);
     }
 
-    auto unary_op = ggml_qnn_unary_op_array()[unary_op_idx];
+    auto unary_op = kQnnUnaryOpsTable[unary_op_idx];
     if (unary_op) {
         return unary_op(ctx, tensor->src[0], tensor);
     }
 
-    auto binary_op = ggml_qnn_binary_op_array()[tensor->op];
+    auto binary_op = kQnnBinaryOpsTable[tensor->op];
     if (binary_op) {
         return binary_op(ctx, tensor->src[0], tensor->src[1], tensor);
     }
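
The call sites in the last three hunks follow the same shape, continuing the sketch above (supports_op and forward here are hypothetical stand-ins for ggml_qnn_supports_op and ggml_qnn_forward): dispatch is a direct array index plus a null check, and the same null check answers the "is this op supported" question, so no wrapper call is needed on either path.

static bool supports_op(my_op op) {
    // a null slot doubles as the "unsupported" signal
    return op < MY_OP_COUNT && kOpsTable[op] != nullptr;
}

static bool forward(my_op op, int src0, int src1, int *dst) {
    auto fn = kOpsTable[op]; // direct table read, no accessor call
    if (!fn) {
        return false;        // op not implemented by this backend
    }
    return fn(src0, src1, dst);
}

This mirrors the commit's point: the removed ggml_qnn_unary_op_array()/ggml_qnn_binary_op_array() wrappers added nothing over indexing the tables directly, and the distinct names kQnnUnaryOpsTable and kQnnBinaryOpsTable let both tables live at namespace scope without the clashing kQnnOpsTable name the wrappers were hiding.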