From 665f823748d13feab4cc747caec1d6896e83ec87 Mon Sep 17 00:00:00 2001
From: hongruichen
Date: Thu, 18 Jul 2024 20:26:05 +0800
Subject: [PATCH] fix op checker

---
 ggml/src/ggml-qnn.cpp | 88 ++++++++++++++++++++++---------------------
 1 file changed, 45 insertions(+), 43 deletions(-)

diff --git a/ggml/src/ggml-qnn.cpp b/ggml/src/ggml-qnn.cpp
index 282a3d859..3f228935c 100644
--- a/ggml/src/ggml-qnn.cpp
+++ b/ggml/src/ggml-qnn.cpp
@@ -134,42 +134,6 @@ struct ggml_backend_qnn_buffer_type_context {
 // implementation of QNN backend for GGML
 //
 // =================================================================================================
-static bool ggml_qnn_can_handle_op(ggml_backend_qnn_context *ctx, const struct ggml_tensor *tensor) {
-    if (ggml_is_empty(tensor)) {
-        return false;
-    }
-
-    if (!qnn::ggml_qnn_unary_op_array()[tensor->op] && !qnn::ggml_qnn_binary_op_array()[tensor->op] &&
-        (tensor->op != GGML_OP_UNARY ||
-         qnn::ggml_qnn_unary_op_array()[qnn::kGgmlUnaryOpStart + ggml_get_unary_op(tensor)])) {
-        return false;
-    }
-
-    const struct ggml_tensor *src0 = tensor->src[0];
-    const struct ggml_tensor *src1 = tensor->src[1];
-    if (!src0 || !src1) {
-        return false;
-    }
-
-    const auto ne00 = src0->ne[0];
-    const auto ne01 = src0->ne[1];
-    const auto ne10 = src1->ne[0];
-    const auto ne11 = src1->ne[1];
-    // make qnn_get_ggml_tensor_rank and QNN SDK happy
-    if (ne00 <= 1 || ne01 <= 1 || ne10 <= 1 || ne11 <= 1) {
-        return false;
-    }
-
-    // TODO: support other quantized data type
-    if (ggml_is_quantized(src0->type)) {
-        if (src0->type != GGML_TYPE_Q8_0 && src0->type != GGML_TYPE_Q4_0) {
-            return false;
-        }
-    }
-
-    return true;
-}
-
 static bool ggml_qnn_compute_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor) {
     size_t unary_op_idx = tensor->op;
     if (tensor->op == GGML_OP_UNARY) {
@@ -297,8 +261,8 @@ GGML_CALL static bool ggml_backend_qnn_buffer_is_host(ggml_backend_buffer_type_t
 }
 
 GGML_CALL static const char *ggml_backend_qnn_name(ggml_backend_t backend) {
-    GGML_UNUSED(backend);
-    return "QNN";
+    ggml_backend_qnn_context *ctx = (ggml_backend_qnn_context *)backend->context;
+    return g_qnn_mgr[ctx->device].name;
 }
 
 GGML_CALL static void ggml_backend_qnn_free(ggml_backend_t backend) {
@@ -353,15 +317,53 @@ GGML_CALL static ggml_status ggml_backend_qnn_graph_compute(ggml_backend_t backe
 }
 
 GGML_CALL static bool ggml_backend_qnn_supports_op(ggml_backend_t backend, const ggml_tensor *op) {
-    ggml_backend_qnn_context *ctx = (ggml_backend_qnn_context *)backend->context;
+    GGML_UNUSED(backend);
 
-    return ggml_qnn_can_handle_op(ctx, op);
+    if (op->op == GGML_OP_NONE) {
+        return true;
+    }
+
+    if (op->op == GGML_OP_UNARY) {
+        if (!qnn::ggml_qnn_unary_op_array()[qnn::kGgmlUnaryOpStart + ggml_get_unary_op(op)]) {
+            QNN_LOG_DEBUG("unsupported unary op %d", ggml_get_unary_op(op));
+            return false;
+        }
+
+        if (!op->src[0]) {
+            QNN_LOG_DEBUG("src0 is nullptr");
+            return false;
+        }
+    } else {
+        if (!qnn::ggml_qnn_unary_op_array()[op->op] && !qnn::ggml_qnn_binary_op_array()[op->op]) {
+            QNN_LOG_DEBUG("unsupported op %d", op->op);
+            return false;
+        }
+
+        if (!op->src[0] || !op->src[1]) {
+            QNN_LOG_DEBUG("src0 or src1 is nullptr");
+            return false;
+        }
+    }
+
+    switch (op->src[0]->type) {
+        case GGML_TYPE_F32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q4_0:
+            break;
+        default:
+            QNN_LOG_DEBUG("unsupported src0 type %d", op->src[0]->type);
+            return false;
+    }
+
+    return true;
 }
 
-GGML_CALL static bool ggml_backend_qnn_offload_op(ggml_backend_t backend, const ggml_tensor *tensor) {
-    ggml_backend_qnn_context *ctx = (ggml_backend_qnn_context *)backend->context;
+GGML_CALL static bool ggml_backend_qnn_offload_op(ggml_backend_t backend, const ggml_tensor *op) {
+    GGML_UNUSED(backend);
 
-    return ggml_qnn_can_handle_op(ctx, tensor);
+    return op->ne[0] > 1 && op->ne[1] > 1;
 }
 
 static ggml_backend_i ggml_backend_qnn_interface = {