fix op checker

hongruichen 2024-07-18 20:26:05 +08:00
parent 15f5cc450c
commit 665f823748


@@ -134,42 +134,6 @@ struct ggml_backend_qnn_buffer_type_context {
 // implementation of QNN backend for GGML
 //
 // =================================================================================================
-static bool ggml_qnn_can_handle_op(ggml_backend_qnn_context *ctx, const struct ggml_tensor *tensor) {
-    if (ggml_is_empty(tensor)) {
-        return false;
-    }
-
-    if (!qnn::ggml_qnn_unary_op_array()[tensor->op] && !qnn::ggml_qnn_binary_op_array()[tensor->op] &&
-        (tensor->op != GGML_OP_UNARY ||
-         !qnn::ggml_qnn_unary_op_array()[qnn::kGgmlUnaryOpStart + ggml_get_unary_op(tensor)])) {
-        return false;
-    }
-
-    const struct ggml_tensor *src0 = tensor->src[0];
-    const struct ggml_tensor *src1 = tensor->src[1];
-    if (!src0 || !src1) {
-        return false;
-    }
-
-    const auto ne00 = src0->ne[0];
-    const auto ne01 = src0->ne[1];
-    const auto ne10 = src1->ne[0];
-    const auto ne11 = src1->ne[1];
-    // make qnn_get_ggml_tensor_rank and QNN SDK happy
-    if (ne00 <= 1 || ne01 <= 1 || ne10 <= 1 || ne11 <= 1) {
-        return false;
-    }
-
-    // TODO: support other quantized data type
-    if (ggml_is_quantized(src0->type)) {
-        if (src0->type != GGML_TYPE_Q8_0 && src0->type != GGML_TYPE_Q4_0) {
-            return false;
-        }
-    }
-
-    return true;
-}
-
 static bool ggml_qnn_compute_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor) {
     size_t unary_op_idx = tensor->op;
     if (tensor->op == GGML_OP_UNARY) {
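Both the removed checker above and its replacement further down resolve handlers through the same flat tables: regular ggml_op values index qnn::ggml_qnn_unary_op_array() / qnn::ggml_qnn_binary_op_array() directly, while GGML_OP_UNARY sub-ops land at qnn::kGgmlUnaryOpStart + ggml_get_unary_op(op). A minimal runnable sketch of that layout; the stand-in enums and table sizes here are assumptions, the real tables live in the qnn:: sources:

#include <array>
#include <cstdio>

// Stand-ins for ggml_op / ggml_unary_op; the real enums come from ggml.h.
enum fake_op { OP_NONE, OP_ADD, OP_UNARY, OP_COUNT };
enum fake_unary_op { UNARY_GELU, UNARY_COUNT };

using op_fn = void (*)();

// Mirrors qnn::kGgmlUnaryOpStart: unary sub-ops are appended after the regular op slots.
constexpr int kUnaryOpStart = OP_COUNT;

static std::array<op_fn, OP_COUNT + UNARY_COUNT> &op_table() {
    static std::array<op_fn, OP_COUNT + UNARY_COUNT> table{};
    return table;
}

static void do_gelu() { std::puts("gelu handler invoked"); }

int main() {
    op_table()[kUnaryOpStart + UNARY_GELU] = do_gelu;  // register one unary handler

    // A null slot is exactly the "unsupported" signal the checker tests for.
    if (!op_table()[OP_ADD]) std::puts("OP_ADD: no handler registered");
    if (op_fn fn = op_table()[kUnaryOpStart + UNARY_GELU]) fn();
}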
@@ -297,8 +261,8 @@ GGML_CALL static bool ggml_backend_qnn_buffer_is_host(ggml_backend_buffer_type_t
 }
 
 GGML_CALL static const char *ggml_backend_qnn_name(ggml_backend_t backend) {
-    GGML_UNUSED(backend);
-    return "QNN";
+    ggml_backend_qnn_context *ctx = (ggml_backend_qnn_context *)backend->context;
+    return g_qnn_mgr[ctx->device].name;
 }
 
 GGML_CALL static void ggml_backend_qnn_free(ggml_backend_t backend) {
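With this change each backend instance reports the name of the device it wraps instead of a fixed "QNN". Only the g_qnn_mgr[ctx->device].name indexing comes from the diff; the entry layout and the device names in this sketch are assumptions:

#include <cstdio>

// Assumed shape of a g_qnn_mgr entry; the real struct carries much more state.
struct qnn_mgr_entry_sketch { const char *name; };

static const qnn_mgr_entry_sketch g_qnn_mgr_sketch[] = {
    {"qnn-cpu"}, {"qnn-gpu"}, {"qnn-npu"},  // hypothetical device names
};

struct qnn_ctx_sketch { int device; };

int main() {
    qnn_ctx_sketch ctx{1};
    // ggml_backend_qnn_name now resolves through the device index:
    std::printf("backend name: %s\n", g_qnn_mgr_sketch[ctx.device].name);
}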
@@ -353,15 +317,53 @@ GGML_CALL static ggml_status ggml_backend_qnn_graph_compute(ggml_backend_t backe
 }
 
 GGML_CALL static bool ggml_backend_qnn_supports_op(ggml_backend_t backend, const ggml_tensor *op) {
-    ggml_backend_qnn_context *ctx = (ggml_backend_qnn_context *)backend->context;
+    GGML_UNUSED(backend);
 
-    return ggml_qnn_can_handle_op(ctx, op);
+    if (op->op == GGML_OP_NONE) {
+        return true;
+    }
+
+    if (op->op == GGML_OP_UNARY) {
+        if (!qnn::ggml_qnn_unary_op_array()[qnn::kGgmlUnaryOpStart + ggml_get_unary_op(op)]) {
+            QNN_LOG_DEBUG("unsupported unary op %d", ggml_get_unary_op(op));
+            return false;
+        }
+
+        if (!op->src[0]) {
+            QNN_LOG_DEBUG("src0 is nullptr");
+            return false;
+        }
+    } else {
+        if (!qnn::ggml_qnn_unary_op_array()[op->op] && !qnn::ggml_qnn_binary_op_array()[op->op]) {
+            QNN_LOG_DEBUG("unsupported op %d", op->op);
+            return false;
+        }
+
+        if (!op->src[0] || !op->src[1]) {
+            QNN_LOG_DEBUG("src0 or src1 is nullptr");
+            return false;
+        }
+    }
+
+    switch (op->src[0]->type) {
+        case GGML_TYPE_F32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q4_0:
+            break;
+        default:
+            QNN_LOG_DEBUG("unsupported src0 type %d", op->src[0]->type);
+            return false;
+    }
+
+    return true;
 }
 
-GGML_CALL static bool ggml_backend_qnn_offload_op(ggml_backend_t backend, const ggml_tensor *tensor) {
-    ggml_backend_qnn_context *ctx = (ggml_backend_qnn_context *)backend->context;
+GGML_CALL static bool ggml_backend_qnn_offload_op(ggml_backend_t backend, const ggml_tensor *op) {
+    GGML_UNUSED(backend);
 
-    return ggml_qnn_can_handle_op(ctx, tensor);
+    return op->ne[0] > 1 && op->ne[1] > 1;
 }
 
 static ggml_backend_i ggml_backend_qnn_interface = {
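Taken together, the new ggml_backend_qnn_supports_op applies three gates in order: an op-table hit, operand presence, and a src0 dtype whitelist. The old rank rejection (ne <= 1) no longer blocks support; it survives only in ggml_backend_qnn_offload_op as a heuristic for when offloading is worthwhile. A runnable sketch of that gating order with stand-in types; the real code reads ggml_tensor and the qnn:: tables:

#include <cstdio>

// Stand-ins; the real checks read ggml_tensor fields and GGML_TYPE_* values.
enum fake_type { TYPE_F32, TYPE_F16, TYPE_I8, TYPE_Q8_0, TYPE_Q4_0, TYPE_Q5_0 };

struct fake_tensor {
    bool has_handler;                        // stands in for the op-table lookup
    const fake_tensor *src0, *src1;
    fake_type type;
};

static bool supports_op_sketch(const fake_tensor &op) {
    if (!op.has_handler) return false;       // gate 1: an op-table entry exists
    if (!op.src0 || !op.src1) return false;  // gate 2: both operands are present
    switch (op.src0->type) {                 // gate 3: src0 dtype whitelist
        case TYPE_F32: case TYPE_F16: case TYPE_I8:
        case TYPE_Q8_0: case TYPE_Q4_0:
            return true;
        default:
            return false;
    }
}

int main() {
    fake_tensor src{true, nullptr, nullptr, TYPE_F32};
    fake_tensor add{true, &src, &src, TYPE_F32};
    std::printf("F32 add supported: %d\n", supports_op_sketch(add));   // 1
    src.type = TYPE_Q5_0;                                              // not whitelisted
    std::printf("Q5_0 add supported: %d\n", supports_op_sketch(add));  // 0
}

Splitting the two predicates this way matches their roles in ggml_backend_i: supports_op answers whether the backend can run an op at all, while offload_op is only a hint that moving the op to this backend is likely profitable.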