refactoring: move forward and supports_op into ops file
This commit is contained in:
parent
867c91bfaf
commit
5da73f8085
3 changed files with 97 additions and 86 deletions
|
@ -114,23 +114,7 @@ struct ggml_backend_qnn_buffer_type_context {
|
|||
//
|
||||
// =================================================================================================
|
||||
static bool ggml_qnn_compute_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor) {
|
||||
size_t unary_op_idx = tensor->op;
|
||||
if (tensor->op == GGML_OP_UNARY) {
|
||||
unary_op_idx = qnn::kGgmlUnaryOpStart + ggml_get_unary_op(tensor);
|
||||
}
|
||||
|
||||
auto unary_op = qnn::ggml_qnn_unary_op_array()[unary_op_idx];
|
||||
if (unary_op) {
|
||||
return unary_op(ctx, tensor->src[0], tensor);
|
||||
}
|
||||
|
||||
auto binary_op = qnn::ggml_qnn_binary_op_array()[tensor->op];
|
||||
if (binary_op) {
|
||||
return binary_op(ctx, tensor->src[0], tensor->src[1], tensor);
|
||||
}
|
||||
|
||||
QNN_LOG_WARN("unsupported op %d", tensor->op);
|
||||
return false;
|
||||
return qnn::ggml_qnn_forward(ctx, tensor);
|
||||
}
|
||||
|
||||
static const char *ggml_backend_qnn_buffer_get_name(ggml_backend_buffer_t buffer) {
|
||||
|
@ -288,42 +272,7 @@ GGML_CALL static ggml_status ggml_backend_qnn_graph_compute(ggml_backend_t backe
|
|||
|
||||
GGML_CALL static bool ggml_backend_qnn_supports_op(ggml_backend_t backend, const ggml_tensor *op) {
|
||||
GGML_UNUSED(backend);
|
||||
|
||||
if (op->op == GGML_OP_UNARY) {
|
||||
if (!qnn::ggml_qnn_unary_op_array()[qnn::kGgmlUnaryOpStart + ggml_get_unary_op(op)]) {
|
||||
QNN_LOG_DEBUG("unsupported unary op %d", ggml_get_unary_op(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!op->src[0]) {
|
||||
QNN_LOG_DEBUG("src0 is nullptr");
|
||||
return false;
|
||||
}
|
||||
} else if (op->op != GGML_OP_NONE) {
|
||||
if (!qnn::ggml_qnn_unary_op_array()[op->op] && !qnn::ggml_qnn_binary_op_array()[op->op]) {
|
||||
QNN_LOG_DEBUG("unsupported op %d", op->op);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!op->src[0] || !op->src[1]) {
|
||||
QNN_LOG_DEBUG("src0 or src1 is nullptr");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
switch (op->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_I8:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
break;
|
||||
default:
|
||||
QNN_LOG_DEBUG("unsupported src0 type %d", op->src[0]->type);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return qnn::ggml_qnn_supports_op(op);
|
||||
}
|
||||
|
||||
GGML_CALL static bool ggml_backend_qnn_offload_op(ggml_backend_t backend, const ggml_tensor *op) {
|
||||
|
|
|
@ -56,6 +56,15 @@ bool qnn_is_valid_params(ggml_backend_qnn_context *ctx, const ggml_tensor *src0,
|
|||
|
||||
namespace {
|
||||
|
||||
typedef bool (*ggml_qnn_unary_op_t)(ggml_backend_qnn_context *ctx, ggml_tensor *src, ggml_tensor *dst);
|
||||
typedef bool (*ggml_qnn_binary_op_t)(ggml_backend_qnn_context *ctx, ggml_tensor *src0, ggml_tensor *src1,
|
||||
ggml_tensor *dst);
|
||||
|
||||
typedef const ggml_qnn_unary_op_t (&ggml_qnn_unary_op_array_t)[GGML_OP_COUNT + GGML_UNARY_OP_COUNT];
|
||||
typedef const ggml_qnn_binary_op_t (&ggml_qnn_binary_op_array_t)[GGML_OP_COUNT];
|
||||
|
||||
constexpr const size_t kGgmlUnaryOpStart = GGML_OP_COUNT;
|
||||
|
||||
void print_ggml_tensor(const ggml_tensor *tensor) {
|
||||
QNN_LOG_DEBUG("%15s: type = %i (%5s) ne = %5" PRIi64 " x %5" PRIi64 " x %5" PRIi64 ", nb = (%5zi, %5zi, %5zi)\n",
|
||||
tensor->name, tensor->type, ggml_type_name(tensor->type), tensor->ne[0], tensor->ne[1], tensor->ne[2],
|
||||
|
@ -106,8 +115,8 @@ qnn::ggml_qnn_graph *get_qnn_graph_from_cache(ggml_backend_qnn_context *ctx, siz
|
|||
GGML_ASSERT(op < (GGML_OP_COUNT + GGML_UNARY_OP_COUNT));
|
||||
|
||||
auto &graph_cache = ctx->qnn_graph_cache;
|
||||
const auto *op_name = op < qnn::kGgmlUnaryOpStart ? ggml_op_name(ggml_op(op))
|
||||
: ggml_unary_op_name(ggml_unary_op(op - qnn::kGgmlUnaryOpStart));
|
||||
const auto *op_name =
|
||||
op < kGgmlUnaryOpStart ? ggml_op_name(ggml_op(op)) : ggml_unary_op_name(ggml_unary_op(op - kGgmlUnaryOpStart));
|
||||
auto graph_key = get_graph_key<_InputSize, _OutputSize>(op_name, inputs, outputs);
|
||||
auto it = graph_cache.find(graph_key);
|
||||
qnn::ggml_qnn_graph *graph_ptr = nullptr;
|
||||
|
@ -237,7 +246,7 @@ constexpr const char *kGgmlOpToQnnOp[] = {
|
|||
|
||||
static_assert(sizeof(kGgmlOpToQnnOp) / sizeof(kGgmlOpToQnnOp[0]) == (GGML_OP_COUNT + GGML_UNARY_OP_COUNT),
|
||||
"GGML_OP_COUNT does not match the size of the kGgmlOpToQnnOp table");
|
||||
static_assert(kGgmlOpToQnnOp[GGML_UNARY_OP_GELU + qnn::kGgmlUnaryOpStart] != nullptr,
|
||||
static_assert(kGgmlOpToQnnOp[GGML_UNARY_OP_GELU + kGgmlUnaryOpStart] != nullptr,
|
||||
"GGML_UNARY_OP_GELU does not correspond to QNN_OP_GELU");
|
||||
|
||||
template <ggml_op _GgmlOp>
|
||||
|
@ -281,10 +290,8 @@ bool qnn_unary_op_impl(ggml_backend_qnn_context *ctx, ggml_tensor *src, ggml_ten
|
|||
return succeed;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
qnn::ggml_qnn_unary_op_array_t qnn::ggml_qnn_unary_op_array() {
|
||||
static constexpr const qnn::ggml_qnn_unary_op_t kQnnOpsTable[] = {
|
||||
ggml_qnn_unary_op_array_t ggml_qnn_unary_op_array() {
|
||||
static constexpr const ggml_qnn_unary_op_t kQnnOpsTable[] = {
|
||||
nullptr, // GGML_OP_NONE
|
||||
nullptr, // GGML_OP_DUP
|
||||
nullptr, // GGML_OP_ADD
|
||||
|
@ -377,7 +384,7 @@ qnn::ggml_qnn_unary_op_array_t qnn::ggml_qnn_unary_op_array() {
|
|||
nullptr, // GGML_UNARY_OP_ELU
|
||||
nullptr, // GGML_UNARY_OP_RELU
|
||||
nullptr, // GGML_UNARY_OP_SIGMOID
|
||||
qnn_unary_op_impl<GGML_UNARY_OP_GELU + qnn::kGgmlUnaryOpStart>, // GGML_UNARY_OP_GELU
|
||||
qnn_unary_op_impl<GGML_UNARY_OP_GELU + kGgmlUnaryOpStart>, // GGML_UNARY_OP_GELU
|
||||
nullptr, // GGML_UNARY_OP_GELU_QUICK
|
||||
nullptr, // GGML_UNARY_OP_SILU
|
||||
nullptr, // GGML_UNARY_OP_HARDSWISH
|
||||
|
@ -389,8 +396,8 @@ qnn::ggml_qnn_unary_op_array_t qnn::ggml_qnn_unary_op_array() {
|
|||
return kQnnOpsTable;
|
||||
}
|
||||
|
||||
qnn::ggml_qnn_binary_op_array_t qnn::ggml_qnn_binary_op_array() {
|
||||
static constexpr const qnn::ggml_qnn_binary_op_t kQnnOpsTable[] = {
|
||||
ggml_qnn_binary_op_array_t ggml_qnn_binary_op_array() {
|
||||
static constexpr const ggml_qnn_binary_op_t kQnnOpsTable[] = {
|
||||
nullptr, // GGML_OP_NONE
|
||||
nullptr, // GGML_OP_DUP
|
||||
qnn_binary_op_impl<GGML_OP_ADD>, // GGML_OP_ADD
|
||||
|
@ -479,3 +486,67 @@ qnn::ggml_qnn_binary_op_array_t qnn::ggml_qnn_binary_op_array() {
|
|||
"GGML_OP_COUNT does not match the size of the ops table");
|
||||
return kQnnOpsTable;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace qnn {
|
||||
|
||||
bool ggml_qnn_supports_op(const ggml_tensor *op) {
|
||||
if (op->op == GGML_OP_UNARY) {
|
||||
if (!ggml_qnn_unary_op_array()[kGgmlUnaryOpStart + ggml_get_unary_op(op)]) {
|
||||
QNN_LOG_DEBUG("unsupported unary op %d", ggml_get_unary_op(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!op->src[0]) {
|
||||
QNN_LOG_DEBUG("src0 is nullptr");
|
||||
return false;
|
||||
}
|
||||
} else if (op->op != GGML_OP_NONE) {
|
||||
if (!ggml_qnn_unary_op_array()[op->op] && !ggml_qnn_binary_op_array()[op->op]) {
|
||||
QNN_LOG_DEBUG("unsupported op %d", op->op);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!op->src[0] || !op->src[1]) {
|
||||
QNN_LOG_DEBUG("src0 or src1 is nullptr");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
switch (op->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_I8:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
break;
|
||||
default:
|
||||
QNN_LOG_DEBUG("unsupported src0 type %d", op->src[0]->type);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ggml_qnn_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor) {
|
||||
size_t unary_op_idx = tensor->op;
|
||||
if (tensor->op == GGML_OP_UNARY) {
|
||||
unary_op_idx = kGgmlUnaryOpStart + ggml_get_unary_op(tensor);
|
||||
}
|
||||
|
||||
auto unary_op = ggml_qnn_unary_op_array()[unary_op_idx];
|
||||
if (unary_op) {
|
||||
return unary_op(ctx, tensor->src[0], tensor);
|
||||
}
|
||||
|
||||
auto binary_op = ggml_qnn_binary_op_array()[tensor->op];
|
||||
if (binary_op) {
|
||||
return binary_op(ctx, tensor->src[0], tensor->src[1], tensor);
|
||||
}
|
||||
|
||||
QNN_LOG_WARN("unsupported op %d", tensor->op);
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace qnn
|
||||
|
|
|
@ -6,16 +6,7 @@
|
|||
|
||||
namespace qnn {
|
||||
|
||||
typedef bool (*ggml_qnn_unary_op_t)(ggml_backend_qnn_context *ctx, ggml_tensor *src, ggml_tensor *dst);
|
||||
typedef bool (*ggml_qnn_binary_op_t)(ggml_backend_qnn_context *ctx, ggml_tensor *src0, ggml_tensor *src1,
|
||||
ggml_tensor *dst);
|
||||
|
||||
typedef const ggml_qnn_unary_op_t (&ggml_qnn_unary_op_array_t)[GGML_OP_COUNT + GGML_UNARY_OP_COUNT];
|
||||
typedef const ggml_qnn_binary_op_t (&ggml_qnn_binary_op_array_t)[GGML_OP_COUNT];
|
||||
|
||||
constexpr const size_t kGgmlUnaryOpStart = GGML_OP_COUNT;
|
||||
|
||||
ggml_qnn_unary_op_array_t ggml_qnn_unary_op_array();
|
||||
ggml_qnn_binary_op_array_t ggml_qnn_binary_op_array();
|
||||
bool ggml_qnn_supports_op(const ggml_tensor *op);
|
||||
bool ggml_qnn_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor);
|
||||
|
||||
} // namespace qnn
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue