refactoring: move forward and supports_op into ops file

hongruichen 2024-07-27 12:52:59 +08:00
parent 867c91bfaf
commit 5da73f8085
3 changed files with 97 additions and 86 deletions


@@ -114,23 +114,7 @@ struct ggml_backend_qnn_buffer_type_context {
 //
 // =================================================================================================
 static bool ggml_qnn_compute_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor) {
-    size_t unary_op_idx = tensor->op;
-    if (tensor->op == GGML_OP_UNARY) {
-        unary_op_idx = qnn::kGgmlUnaryOpStart + ggml_get_unary_op(tensor);
-    }
-
-    auto unary_op = qnn::ggml_qnn_unary_op_array()[unary_op_idx];
-    if (unary_op) {
-        return unary_op(ctx, tensor->src[0], tensor);
-    }
-
-    auto binary_op = qnn::ggml_qnn_binary_op_array()[tensor->op];
-    if (binary_op) {
-        return binary_op(ctx, tensor->src[0], tensor->src[1], tensor);
-    }
-
-    QNN_LOG_WARN("unsupported op %d", tensor->op);
-    return false;
+    return qnn::ggml_qnn_forward(ctx, tensor);
 }

 static const char *ggml_backend_qnn_buffer_get_name(ggml_backend_buffer_t buffer) {
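The body being removed above is a function-pointer table lookup; the commit moves it, unchanged, into the ops file. For reference, a minimal self-contained sketch of that dispatch pattern, using purely illustrative toy_* names (nothing below is from the commit itself):

#include <cstdio>

// Toy op set standing in for ggml_op (illustrative only).
enum toy_op { TOY_OP_NONE, TOY_OP_ADD, TOY_OP_MUL, TOY_OP_COUNT };

typedef bool (*toy_op_fn)(int a, int b, int *dst);

static bool toy_add(int a, int b, int *dst) { *dst = a + b; return true; }

// One table entry per op; nullptr marks an unsupported op.
static constexpr toy_op_fn kToyTable[TOY_OP_COUNT] = {
    nullptr,  // TOY_OP_NONE
    toy_add,  // TOY_OP_ADD
    nullptr,  // TOY_OP_MUL: no kernel registered
};

static bool toy_forward(toy_op op, int a, int b, int *dst) {
    if (toy_op_fn fn = kToyTable[op]) {
        return fn(a, b, dst);  // found a kernel, run it
    }
    return false;  // unsupported op, let the caller fall back
}

int main() {
    int out = 0;
    std::printf("add supported: %d, out = %d\n", toy_forward(TOY_OP_ADD, 2, 3, &out), out);
    std::printf("mul supported: %d\n", toy_forward(TOY_OP_MUL, 2, 3, &out));
}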
@@ -288,42 +272,7 @@ GGML_CALL static ggml_status ggml_backend_qnn_graph_compute(ggml_backend_t backe
 GGML_CALL static bool ggml_backend_qnn_supports_op(ggml_backend_t backend, const ggml_tensor *op) {
     GGML_UNUSED(backend);
-    if (op->op == GGML_OP_UNARY) {
-        if (!qnn::ggml_qnn_unary_op_array()[qnn::kGgmlUnaryOpStart + ggml_get_unary_op(op)]) {
-            QNN_LOG_DEBUG("unsupported unary op %d", ggml_get_unary_op(op));
-            return false;
-        }
-
-        if (!op->src[0]) {
-            QNN_LOG_DEBUG("src0 is nullptr");
-            return false;
-        }
-    } else if (op->op != GGML_OP_NONE) {
-        if (!qnn::ggml_qnn_unary_op_array()[op->op] && !qnn::ggml_qnn_binary_op_array()[op->op]) {
-            QNN_LOG_DEBUG("unsupported op %d", op->op);
-            return false;
-        }
-
-        if (!op->src[0] || !op->src[1]) {
-            QNN_LOG_DEBUG("src0 or src1 is nullptr");
-            return false;
-        }
-    }
-
-    switch (op->type) {
-        case GGML_TYPE_F32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q4_0:
-            break;
-        default:
-            QNN_LOG_DEBUG("unsupported src0 type %d", op->src[0]->type);
-            return false;
-    }
-
-    return true;
+    return qnn::ggml_qnn_supports_op(op);
 }

 GGML_CALL static bool ggml_backend_qnn_offload_op(ggml_backend_t backend, const ggml_tensor *op) {
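The removed body (now qnn::ggml_qnn_supports_op in the ops file, shown below) follows a default-deny shape: look the op up in the same tables used for execution, then allow-list tensor types in a switch so anything unlisted is rejected. A hedged toy of the type-gating half, with illustrative TOY_* names:

#include <cstdio>

enum toy_type { TOY_TYPE_F32, TOY_TYPE_F16, TOY_TYPE_I64, TOY_TYPE_COUNT };

// Allow-list the handful of supported tensor types; everything else falls
// through to the default branch, so new types are opt-in, never opt-out.
static bool toy_supports_type(toy_type type) {
    switch (type) {
        case TOY_TYPE_F32:
        case TOY_TYPE_F16:
            return true;
        default:
            return false;
    }
}

int main() {
    std::printf("f32: %d, i64: %d\n", toy_supports_type(TOY_TYPE_F32), toy_supports_type(TOY_TYPE_I64));
}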


@@ -56,6 +56,15 @@ bool qnn_is_valid_params(ggml_backend_qnn_context *ctx, const ggml_tensor *src0,
 namespace {

+typedef bool (*ggml_qnn_unary_op_t)(ggml_backend_qnn_context *ctx, ggml_tensor *src, ggml_tensor *dst);
+typedef bool (*ggml_qnn_binary_op_t)(ggml_backend_qnn_context *ctx, ggml_tensor *src0, ggml_tensor *src1,
+                                     ggml_tensor *dst);
+
+typedef const ggml_qnn_unary_op_t (&ggml_qnn_unary_op_array_t)[GGML_OP_COUNT + GGML_UNARY_OP_COUNT];
+typedef const ggml_qnn_binary_op_t (&ggml_qnn_binary_op_array_t)[GGML_OP_COUNT];
+
+constexpr const size_t kGgmlUnaryOpStart = GGML_OP_COUNT;
+
 void print_ggml_tensor(const ggml_tensor *tensor) {
     QNN_LOG_DEBUG("%15s: type = %i (%5s) ne = %5" PRIi64 " x %5" PRIi64 " x %5" PRIi64 ", nb = (%5zi, %5zi, %5zi)\n",
                   tensor->name, tensor->type, ggml_type_name(tensor->type), tensor->ne[0], tensor->ne[1], tensor->ne[2],
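kGgmlUnaryOpStart folds the two ggml enums into one flat index space: regular ops occupy [0, GGML_OP_COUNT) and unary ops are appended at [GGML_OP_COUNT, GGML_OP_COUNT + GGML_UNARY_OP_COUNT). A self-contained sketch of that indexing, with toy enums standing in for ggml_op and ggml_unary_op (assumed values, for illustration only):

#include <cstddef>

enum toy_op { TOY_OP_NONE, TOY_OP_ADD, TOY_OP_UNARY, TOY_OP_COUNT };
enum toy_unary_op { TOY_UNARY_OP_GELU, TOY_UNARY_OP_RELU, TOY_UNARY_OP_COUNT };

// Unary ops start right after the regular ops in the shared table.
constexpr size_t kToyUnaryOpStart = TOY_OP_COUNT;

// Map (op, unary sub-op) to a slot in one flat table of size
// TOY_OP_COUNT + TOY_UNARY_OP_COUNT.
constexpr size_t toy_table_index(toy_op op, toy_unary_op unary) {
    return op == TOY_OP_UNARY ? kToyUnaryOpStart + unary : size_t(op);
}

static_assert(toy_table_index(TOY_OP_ADD, TOY_UNARY_OP_GELU) == 1, "regular ops index directly");
static_assert(toy_table_index(TOY_OP_UNARY, TOY_UNARY_OP_RELU) == TOY_OP_COUNT + 1,
              "unary ops land after the regular ops");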
@@ -106,8 +115,8 @@ qnn::ggml_qnn_graph *get_qnn_graph_from_cache(ggml_backend_qnn_context *ctx, siz
     GGML_ASSERT(op < (GGML_OP_COUNT + GGML_UNARY_OP_COUNT));

     auto &graph_cache = ctx->qnn_graph_cache;
-    const auto *op_name = op < qnn::kGgmlUnaryOpStart ? ggml_op_name(ggml_op(op))
-                                                      : ggml_unary_op_name(ggml_unary_op(op - qnn::kGgmlUnaryOpStart));
+    const auto *op_name =
+        op < kGgmlUnaryOpStart ? ggml_op_name(ggml_op(op)) : ggml_unary_op_name(ggml_unary_op(op - kGgmlUnaryOpStart));
     auto graph_key = get_graph_key<_InputSize, _OutputSize>(op_name, inputs, outputs);
     auto it = graph_cache.find(graph_key);
     qnn::ggml_qnn_graph *graph_ptr = nullptr;
@@ -237,7 +246,7 @@ constexpr const char *kGgmlOpToQnnOp[] = {
 static_assert(sizeof(kGgmlOpToQnnOp) / sizeof(kGgmlOpToQnnOp[0]) == (GGML_OP_COUNT + GGML_UNARY_OP_COUNT),
               "GGML_OP_COUNT does not match the size of the kGgmlOpToQnnOp table");
-static_assert(kGgmlOpToQnnOp[GGML_UNARY_OP_GELU + qnn::kGgmlUnaryOpStart] != nullptr,
+static_assert(kGgmlOpToQnnOp[GGML_UNARY_OP_GELU + kGgmlUnaryOpStart] != nullptr,
              "GGML_UNARY_OP_GELU does not correspond to QNN_OP_GELU");

 template <ggml_op _GgmlOp>
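These static_asserts are the safety net for the parallel tables: if ggml gains an op and the enum shifts, the build breaks instead of silently dispatching the wrong kernel. The same trick in isolation, with toy names (illustrative, not from this codebase):

enum toy_op { TOY_OP_A, TOY_OP_B, TOY_OP_COUNT };

constexpr const char *kToyOpNames[] = {
    "A",  // TOY_OP_A
    "B",  // TOY_OP_B
};

// Length check: adding a toy_op without adding a table entry fails to compile.
static_assert(sizeof(kToyOpNames) / sizeof(kToyOpNames[0]) == TOY_OP_COUNT,
              "kToyOpNames is out of sync with toy_op");

// Spot-check one entry so a reordered enum is also caught at compile time.
static_assert(kToyOpNames[TOY_OP_B][0] == 'B', "TOY_OP_B maps to the wrong entry");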
@@ -281,10 +290,8 @@ bool qnn_unary_op_impl(ggml_backend_qnn_context *ctx, ggml_tensor *src, ggml_ten
     return succeed;
 }

-} // namespace
-
-qnn::ggml_qnn_unary_op_array_t qnn::ggml_qnn_unary_op_array() {
-    static constexpr const qnn::ggml_qnn_unary_op_t kQnnOpsTable[] = {
+ggml_qnn_unary_op_array_t ggml_qnn_unary_op_array() {
+    static constexpr const ggml_qnn_unary_op_t kQnnOpsTable[] = {
         nullptr, // GGML_OP_NONE
         nullptr, // GGML_OP_DUP
         nullptr, // GGML_OP_ADD
@@ -369,19 +376,19 @@ qnn::ggml_qnn_unary_op_array_t qnn::ggml_qnn_unary_op_array() {
         nullptr, // GGML_OP_CROSS_ENTROPY_LOSS_BACK

         // ggml_unary_op
-        nullptr,                                                        // GGML_UNARY_OP_ABS
-        nullptr,                                                        // GGML_UNARY_OP_SGN
-        nullptr,                                                        // GGML_UNARY_OP_NEG
-        nullptr,                                                        // GGML_UNARY_OP_STEP
-        nullptr,                                                        // GGML_UNARY_OP_TANH
-        nullptr,                                                        // GGML_UNARY_OP_ELU
-        nullptr,                                                        // GGML_UNARY_OP_RELU
-        nullptr,                                                        // GGML_UNARY_OP_SIGMOID
-        qnn_unary_op_impl<GGML_UNARY_OP_GELU + qnn::kGgmlUnaryOpStart>, // GGML_UNARY_OP_GELU
-        nullptr,                                                        // GGML_UNARY_OP_GELU_QUICK
-        nullptr,                                                        // GGML_UNARY_OP_SILU
-        nullptr,                                                        // GGML_UNARY_OP_HARDSWISH
-        nullptr,                                                        // GGML_UNARY_OP_HARDSIGMOID
+        nullptr,                                                   // GGML_UNARY_OP_ABS
+        nullptr,                                                   // GGML_UNARY_OP_SGN
+        nullptr,                                                   // GGML_UNARY_OP_NEG
+        nullptr,                                                   // GGML_UNARY_OP_STEP
+        nullptr,                                                   // GGML_UNARY_OP_TANH
+        nullptr,                                                   // GGML_UNARY_OP_ELU
+        nullptr,                                                   // GGML_UNARY_OP_RELU
+        nullptr,                                                   // GGML_UNARY_OP_SIGMOID
+        qnn_unary_op_impl<GGML_UNARY_OP_GELU + kGgmlUnaryOpStart>, // GGML_UNARY_OP_GELU
+        nullptr,                                                   // GGML_UNARY_OP_GELU_QUICK
+        nullptr,                                                   // GGML_UNARY_OP_SILU
+        nullptr,                                                   // GGML_UNARY_OP_HARDSWISH
+        nullptr,                                                   // GGML_UNARY_OP_HARDSIGMOID
     };

     static_assert(sizeof(kQnnOpsTable) / sizeof(kQnnOpsTable[0]) == (GGML_OP_COUNT + GGML_UNARY_OP_COUNT),
@@ -389,8 +396,8 @@ qnn::ggml_qnn_unary_op_array_t qnn::ggml_qnn_unary_op_array() {
                   "GGML_OP_COUNT does not match the size of the ops table");
     return kQnnOpsTable;
 }

-qnn::ggml_qnn_binary_op_array_t qnn::ggml_qnn_binary_op_array() {
-    static constexpr const qnn::ggml_qnn_binary_op_t kQnnOpsTable[] = {
+ggml_qnn_binary_op_array_t ggml_qnn_binary_op_array() {
+    static constexpr const ggml_qnn_binary_op_t kQnnOpsTable[] = {
         nullptr,                         // GGML_OP_NONE
         nullptr,                         // GGML_OP_DUP
         qnn_binary_op_impl<GGML_OP_ADD>, // GGML_OP_ADD
@@ -479,3 +486,67 @@ qnn::ggml_qnn_binary_op_array_t qnn::ggml_qnn_binary_op_array() {
                   "GGML_OP_COUNT does not match the size of the ops table");
     return kQnnOpsTable;
 }
+
+} // namespace
+
+namespace qnn {
+
+bool ggml_qnn_supports_op(const ggml_tensor *op) {
+    if (op->op == GGML_OP_UNARY) {
+        if (!ggml_qnn_unary_op_array()[kGgmlUnaryOpStart + ggml_get_unary_op(op)]) {
+            QNN_LOG_DEBUG("unsupported unary op %d", ggml_get_unary_op(op));
+            return false;
+        }
+
+        if (!op->src[0]) {
+            QNN_LOG_DEBUG("src0 is nullptr");
+            return false;
+        }
+    } else if (op->op != GGML_OP_NONE) {
+        if (!ggml_qnn_unary_op_array()[op->op] && !ggml_qnn_binary_op_array()[op->op]) {
+            QNN_LOG_DEBUG("unsupported op %d", op->op);
+            return false;
+        }
+
+        if (!op->src[0] || !op->src[1]) {
+            QNN_LOG_DEBUG("src0 or src1 is nullptr");
+            return false;
+        }
+    }
+
+    switch (op->type) {
+        case GGML_TYPE_F32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q4_0:
+            break;
+        default:
+            QNN_LOG_DEBUG("unsupported src0 type %d", op->src[0]->type);
+            return false;
+    }
+
+    return true;
+}
+
+bool ggml_qnn_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor) {
+    size_t unary_op_idx = tensor->op;
+    if (tensor->op == GGML_OP_UNARY) {
+        unary_op_idx = kGgmlUnaryOpStart + ggml_get_unary_op(tensor);
+    }
+
+    auto unary_op = ggml_qnn_unary_op_array()[unary_op_idx];
+    if (unary_op) {
+        return unary_op(ctx, tensor->src[0], tensor);
+    }
+
+    auto binary_op = ggml_qnn_binary_op_array()[tensor->op];
+    if (binary_op) {
+        return binary_op(ctx, tensor->src[0], tensor->src[1], tensor);
+    }
+
+    QNN_LOG_WARN("unsupported op %d", tensor->op);
+    return false;
+}
+
+} // namespace qnn
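The end state of this file is the point of the refactor: the typedefs, tables, and impl templates sit in an anonymous namespace (internal linkage), and only the two functions in namespace qnn are visible to the backend. A minimal sketch of that translation-unit layout, with illustrative names:

namespace {

// Internal linkage: invisible outside this translation unit.
int pick_kernel() { return 42; }

} // namespace

namespace demo {

// The only symbol other files can link against.
int forward() { return pick_kernel(); }

} // namespace demo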


@@ -6,16 +6,7 @@

 namespace qnn {

-typedef bool (*ggml_qnn_unary_op_t)(ggml_backend_qnn_context *ctx, ggml_tensor *src, ggml_tensor *dst);
-typedef bool (*ggml_qnn_binary_op_t)(ggml_backend_qnn_context *ctx, ggml_tensor *src0, ggml_tensor *src1,
-                                     ggml_tensor *dst);
-
-typedef const ggml_qnn_unary_op_t (&ggml_qnn_unary_op_array_t)[GGML_OP_COUNT + GGML_UNARY_OP_COUNT];
-typedef const ggml_qnn_binary_op_t (&ggml_qnn_binary_op_array_t)[GGML_OP_COUNT];
-
-constexpr const size_t kGgmlUnaryOpStart = GGML_OP_COUNT;
-
-ggml_qnn_unary_op_array_t ggml_qnn_unary_op_array();
-ggml_qnn_binary_op_array_t ggml_qnn_binary_op_array();
+bool ggml_qnn_supports_op(const ggml_tensor *op);
+bool ggml_qnn_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor);

 } // namespace qnn