fix op checker

parent 15f5cc450c
commit 665f823748

1 changed file with 45 additions and 43 deletions
@@ -134,42 +134,6 @@ struct ggml_backend_qnn_buffer_type_context {
 // implementation of QNN backend for GGML
 //
 // =================================================================================================
-static bool ggml_qnn_can_handle_op(ggml_backend_qnn_context *ctx, const struct ggml_tensor *tensor) {
-    if (ggml_is_empty(tensor)) {
-        return false;
-    }
-
-    if (!qnn::ggml_qnn_unary_op_array()[tensor->op] && !qnn::ggml_qnn_binary_op_array()[tensor->op] &&
-        (tensor->op != GGML_OP_UNARY ||
-         qnn::ggml_qnn_unary_op_array()[qnn::kGgmlUnaryOpStart + ggml_get_unary_op(tensor)])) {
-        return false;
-    }
-
-    const struct ggml_tensor *src0 = tensor->src[0];
-    const struct ggml_tensor *src1 = tensor->src[1];
-    if (!src0 || !src1) {
-        return false;
-    }
-
-    const auto ne00 = src0->ne[0];
-    const auto ne01 = src0->ne[1];
-    const auto ne10 = src1->ne[0];
-    const auto ne11 = src1->ne[1];
-    // make qnn_get_ggml_tensor_rank and QNN SDK happy
-    if (ne00 <= 1 || ne01 <= 1 || ne10 <= 1 || ne11 <= 1) {
-        return false;
-    }
-
-    // TODO: support other quantized data type
-    if (ggml_is_quantized(src0->type)) {
-        if (src0->type != GGML_TYPE_Q8_0 && src0->type != GGML_TYPE_Q4_0) {
-            return false;
-        }
-    }
-
-    return true;
-}
-
 static bool ggml_qnn_compute_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor) {
     size_t unary_op_idx = tensor->op;
     if (tensor->op == GGML_OP_UNARY) {
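For context on the indexing used both in the removed checker above and in the new ggml_backend_qnn_supports_op below: regular ops index the op tables by tensor->op, while GGML_OP_UNARY sub-ops are looked up at kGgmlUnaryOpStart + ggml_get_unary_op(...). The defect addressed by this commit appears to be in the removed condition's last clause, which tested that the unary slot is populated, so a GGML_OP_UNARY with a registered kernel was rejected; the replacement negates that lookup. A minimal sketch of the table layout, with placeholder sizes and types that are assumptions rather than the real qnn:: implementation:

    // Sketch only (not the actual qnn:: code): one function-pointer slot per op,
    // nullptr meaning "no QNN kernel", with GGML_UNARY_OP_* sub-ops appended
    // after the regular op range starting at kGgmlUnaryOpStart.
    #include <array>
    #include <cstddef>

    using qnn_op_fn = bool (*)(void * ctx, void * dst);

    constexpr std::size_t kGgmlOpCount      = 80;            // assumption, not the real GGML_OP_COUNT
    constexpr std::size_t kGgmlUnaryOpStart = kGgmlOpCount;  // unary sub-ops stored after regular ops
    constexpr std::size_t kOpTableSize      = kGgmlUnaryOpStart + 16;

    const std::array<qnn_op_fn, kOpTableSize> & unary_op_table() {
        static const std::array<qnn_op_fn, kOpTableSize> table = {};  // all nullptr in this sketch
        return table;
    }

    // Regular op: table[op]; GGML_OP_UNARY: table[kGgmlUnaryOpStart + ggml_get_unary_op(t)].
    inline bool unary_slot_filled(std::size_t index) { return unary_op_table()[index] != nullptr; }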
@@ -297,8 +261,8 @@ GGML_CALL static bool ggml_backend_qnn_buffer_is_host(ggml_backend_buffer_type_t
 }
 
 GGML_CALL static const char *ggml_backend_qnn_name(ggml_backend_t backend) {
-    GGML_UNUSED(backend);
-    return "QNN";
+    ggml_backend_qnn_context *ctx = (ggml_backend_qnn_context *)backend->context;
+    return g_qnn_mgr[ctx->device].name;
 }
 
 GGML_CALL static void ggml_backend_qnn_free(ggml_backend_t backend) {
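With the change above, each backend instance reports a device-specific name read from g_qnn_mgr[ctx->device].name instead of the fixed string "QNN". A hypothetical shape of that per-device table, shown only to illustrate the lookup; the field layout and device names are assumptions, not taken from this diff:

    // Illustrative assumption of what g_qnn_mgr indexes into; not the real struct.
    struct qnn_mgr_entry_sketch {
        int          device;
        const char * name;  // e.g. "qnn-cpu", "qnn-gpu", "qnn-npu" (assumed values)
    };

    static const qnn_mgr_entry_sketch g_qnn_mgr_sketch[] = {
        { 0, "qnn-cpu" },
        { 1, "qnn-gpu" },
        { 2, "qnn-npu" },
    };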
@@ -353,15 +317,53 @@ GGML_CALL static ggml_status ggml_backend_qnn_graph_compute(ggml_backend_t backe
 }
 
 GGML_CALL static bool ggml_backend_qnn_supports_op(ggml_backend_t backend, const ggml_tensor *op) {
-    ggml_backend_qnn_context *ctx = (ggml_backend_qnn_context *)backend->context;
+    GGML_UNUSED(backend);
 
-    return ggml_qnn_can_handle_op(ctx, op);
+    if (op->op == GGML_OP_NONE) {
+        return true;
+    }
+
+    if (op->op == GGML_OP_UNARY) {
+        if (!qnn::ggml_qnn_unary_op_array()[qnn::kGgmlUnaryOpStart + ggml_get_unary_op(op)]) {
+            QNN_LOG_DEBUG("unsupported unary op %d", ggml_get_unary_op(op));
+            return false;
+        }
+
+        if (!op->src[0]) {
+            QNN_LOG_DEBUG("src0 is nullptr");
+            return false;
+        }
+    } else {
+        if (!qnn::ggml_qnn_unary_op_array()[op->op] && !qnn::ggml_qnn_binary_op_array()[op->op]) {
+            QNN_LOG_DEBUG("unsupported op %d", op->op);
+            return false;
+        }
+
+        if (!op->src[0] || !op->src[1]) {
+            QNN_LOG_DEBUG("src0 or src1 is nullptr");
+            return false;
+        }
+    }
+
+    switch (op->src[0]->type) {
+        case GGML_TYPE_F32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q4_0:
+            break;
+        default:
+            QNN_LOG_DEBUG("unsupported src0 type %d", op->src[0]->type);
+            return false;
+    }
+
+    return true;
 }
 
-GGML_CALL static bool ggml_backend_qnn_offload_op(ggml_backend_t backend, const ggml_tensor *tensor) {
-    ggml_backend_qnn_context *ctx = (ggml_backend_qnn_context *)backend->context;
+GGML_CALL static bool ggml_backend_qnn_offload_op(ggml_backend_t backend, const ggml_tensor *op) {
+    GGML_UNUSED(backend);
 
-    return ggml_qnn_can_handle_op(ctx, tensor);
+    return op->ne[0] > 1 && op->ne[1] > 1;
 }
 
 static ggml_backend_i ggml_backend_qnn_interface = {
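After this hunk the two callbacks divide the work: supports_op decides whether a QNN kernel exists for the op and its src0 type, while offload_op is a cheap shape heuristic (ne[0] > 1 && ne[1] > 1). A sketch of how a caller could exercise both through the public ggml-backend entry points; it illustrates the division of labor only and is not the scheduler's actual code:

    #include "ggml-backend.h"

    // Sketch only: ggml_backend_supports_op / ggml_backend_offload_op dispatch to
    // the .supports_op / .offload_op members of ggml_backend_qnn_interface.
    static bool worth_running_on_qnn(ggml_backend_t qnn_backend, const struct ggml_tensor * op) {
        if (!ggml_backend_supports_op(qnn_backend, op)) {
            return false;                                 // no QNN kernel for this op / src0 type
        }
        return ggml_backend_offload_op(qnn_backend, op);  // shape heuristic from this commit
    }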