refactoring: move `forward` and `supports_op` into the ops file
This commit is contained in:
parent
867c91bfaf
commit
5da73f8085
3 changed files with 97 additions and 86 deletions
|
@ -114,23 +114,7 @@ struct ggml_backend_qnn_buffer_type_context {
|
||||||
//
|
//
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
static bool ggml_qnn_compute_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor) {
|
static bool ggml_qnn_compute_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor) {
|
||||||
size_t unary_op_idx = tensor->op;
|
return qnn::ggml_qnn_forward(ctx, tensor);
|
||||||
if (tensor->op == GGML_OP_UNARY) {
|
|
||||||
unary_op_idx = qnn::kGgmlUnaryOpStart + ggml_get_unary_op(tensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto unary_op = qnn::ggml_qnn_unary_op_array()[unary_op_idx];
|
|
||||||
if (unary_op) {
|
|
||||||
return unary_op(ctx, tensor->src[0], tensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto binary_op = qnn::ggml_qnn_binary_op_array()[tensor->op];
|
|
||||||
if (binary_op) {
|
|
||||||
return binary_op(ctx, tensor->src[0], tensor->src[1], tensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
QNN_LOG_WARN("unsupported op %d", tensor->op);
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static const char *ggml_backend_qnn_buffer_get_name(ggml_backend_buffer_t buffer) {
|
static const char *ggml_backend_qnn_buffer_get_name(ggml_backend_buffer_t buffer) {
|
||||||
|
@ -288,42 +272,7 @@ GGML_CALL static ggml_status ggml_backend_qnn_graph_compute(ggml_backend_t backe
|
||||||
|
|
||||||
GGML_CALL static bool ggml_backend_qnn_supports_op(ggml_backend_t backend, const ggml_tensor *op) {
|
GGML_CALL static bool ggml_backend_qnn_supports_op(ggml_backend_t backend, const ggml_tensor *op) {
|
||||||
GGML_UNUSED(backend);
|
GGML_UNUSED(backend);
|
||||||
|
return qnn::ggml_qnn_supports_op(op);
|
||||||
if (op->op == GGML_OP_UNARY) {
|
|
||||||
if (!qnn::ggml_qnn_unary_op_array()[qnn::kGgmlUnaryOpStart + ggml_get_unary_op(op)]) {
|
|
||||||
QNN_LOG_DEBUG("unsupported unary op %d", ggml_get_unary_op(op));
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!op->src[0]) {
|
|
||||||
QNN_LOG_DEBUG("src0 is nullptr");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
} else if (op->op != GGML_OP_NONE) {
|
|
||||||
if (!qnn::ggml_qnn_unary_op_array()[op->op] && !qnn::ggml_qnn_binary_op_array()[op->op]) {
|
|
||||||
QNN_LOG_DEBUG("unsupported op %d", op->op);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!op->src[0] || !op->src[1]) {
|
|
||||||
QNN_LOG_DEBUG("src0 or src1 is nullptr");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (op->type) {
|
|
||||||
case GGML_TYPE_F32:
|
|
||||||
case GGML_TYPE_F16:
|
|
||||||
case GGML_TYPE_I8:
|
|
||||||
case GGML_TYPE_Q8_0:
|
|
||||||
case GGML_TYPE_Q4_0:
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
QNN_LOG_DEBUG("unsupported src0 type %d", op->src[0]->type);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_CALL static bool ggml_backend_qnn_offload_op(ggml_backend_t backend, const ggml_tensor *op) {
|
GGML_CALL static bool ggml_backend_qnn_offload_op(ggml_backend_t backend, const ggml_tensor *op) {
|
||||||
|
|
|
@ -56,6 +56,15 @@ bool qnn_is_valid_params(ggml_backend_qnn_context *ctx, const ggml_tensor *src0,
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
typedef bool (*ggml_qnn_unary_op_t)(ggml_backend_qnn_context *ctx, ggml_tensor *src, ggml_tensor *dst);
|
||||||
|
typedef bool (*ggml_qnn_binary_op_t)(ggml_backend_qnn_context *ctx, ggml_tensor *src0, ggml_tensor *src1,
|
||||||
|
ggml_tensor *dst);
|
||||||
|
|
||||||
|
typedef const ggml_qnn_unary_op_t (&ggml_qnn_unary_op_array_t)[GGML_OP_COUNT + GGML_UNARY_OP_COUNT];
|
||||||
|
typedef const ggml_qnn_binary_op_t (&ggml_qnn_binary_op_array_t)[GGML_OP_COUNT];
|
||||||
|
|
||||||
|
constexpr const size_t kGgmlUnaryOpStart = GGML_OP_COUNT;
|
||||||
|
|
||||||
void print_ggml_tensor(const ggml_tensor *tensor) {
|
void print_ggml_tensor(const ggml_tensor *tensor) {
|
||||||
QNN_LOG_DEBUG("%15s: type = %i (%5s) ne = %5" PRIi64 " x %5" PRIi64 " x %5" PRIi64 ", nb = (%5zi, %5zi, %5zi)\n",
|
QNN_LOG_DEBUG("%15s: type = %i (%5s) ne = %5" PRIi64 " x %5" PRIi64 " x %5" PRIi64 ", nb = (%5zi, %5zi, %5zi)\n",
|
||||||
tensor->name, tensor->type, ggml_type_name(tensor->type), tensor->ne[0], tensor->ne[1], tensor->ne[2],
|
tensor->name, tensor->type, ggml_type_name(tensor->type), tensor->ne[0], tensor->ne[1], tensor->ne[2],
|
||||||
|
@ -106,8 +115,8 @@ qnn::ggml_qnn_graph *get_qnn_graph_from_cache(ggml_backend_qnn_context *ctx, siz
|
||||||
GGML_ASSERT(op < (GGML_OP_COUNT + GGML_UNARY_OP_COUNT));
|
GGML_ASSERT(op < (GGML_OP_COUNT + GGML_UNARY_OP_COUNT));
|
||||||
|
|
||||||
auto &graph_cache = ctx->qnn_graph_cache;
|
auto &graph_cache = ctx->qnn_graph_cache;
|
||||||
const auto *op_name = op < qnn::kGgmlUnaryOpStart ? ggml_op_name(ggml_op(op))
|
const auto *op_name =
|
||||||
: ggml_unary_op_name(ggml_unary_op(op - qnn::kGgmlUnaryOpStart));
|
op < kGgmlUnaryOpStart ? ggml_op_name(ggml_op(op)) : ggml_unary_op_name(ggml_unary_op(op - kGgmlUnaryOpStart));
|
||||||
auto graph_key = get_graph_key<_InputSize, _OutputSize>(op_name, inputs, outputs);
|
auto graph_key = get_graph_key<_InputSize, _OutputSize>(op_name, inputs, outputs);
|
||||||
auto it = graph_cache.find(graph_key);
|
auto it = graph_cache.find(graph_key);
|
||||||
qnn::ggml_qnn_graph *graph_ptr = nullptr;
|
qnn::ggml_qnn_graph *graph_ptr = nullptr;
|
||||||
|
@ -237,7 +246,7 @@ constexpr const char *kGgmlOpToQnnOp[] = {
|
||||||
|
|
||||||
static_assert(sizeof(kGgmlOpToQnnOp) / sizeof(kGgmlOpToQnnOp[0]) == (GGML_OP_COUNT + GGML_UNARY_OP_COUNT),
|
static_assert(sizeof(kGgmlOpToQnnOp) / sizeof(kGgmlOpToQnnOp[0]) == (GGML_OP_COUNT + GGML_UNARY_OP_COUNT),
|
||||||
"GGML_OP_COUNT does not match the size of the kGgmlOpToQnnOp table");
|
"GGML_OP_COUNT does not match the size of the kGgmlOpToQnnOp table");
|
||||||
static_assert(kGgmlOpToQnnOp[GGML_UNARY_OP_GELU + qnn::kGgmlUnaryOpStart] != nullptr,
|
static_assert(kGgmlOpToQnnOp[GGML_UNARY_OP_GELU + kGgmlUnaryOpStart] != nullptr,
|
||||||
"GGML_UNARY_OP_GELU does not correspond to QNN_OP_GELU");
|
"GGML_UNARY_OP_GELU does not correspond to QNN_OP_GELU");
|
||||||
|
|
||||||
template <ggml_op _GgmlOp>
|
template <ggml_op _GgmlOp>
|
||||||
|
@ -281,10 +290,8 @@ bool qnn_unary_op_impl(ggml_backend_qnn_context *ctx, ggml_tensor *src, ggml_ten
|
||||||
return succeed;
|
return succeed;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
ggml_qnn_unary_op_array_t ggml_qnn_unary_op_array() {
|
||||||
|
static constexpr const ggml_qnn_unary_op_t kQnnOpsTable[] = {
|
||||||
qnn::ggml_qnn_unary_op_array_t qnn::ggml_qnn_unary_op_array() {
|
|
||||||
static constexpr const qnn::ggml_qnn_unary_op_t kQnnOpsTable[] = {
|
|
||||||
nullptr, // GGML_OP_NONE
|
nullptr, // GGML_OP_NONE
|
||||||
nullptr, // GGML_OP_DUP
|
nullptr, // GGML_OP_DUP
|
||||||
nullptr, // GGML_OP_ADD
|
nullptr, // GGML_OP_ADD
|
||||||
|
@ -369,19 +376,19 @@ qnn::ggml_qnn_unary_op_array_t qnn::ggml_qnn_unary_op_array() {
|
||||||
nullptr, // GGML_OP_CROSS_ENTROPY_LOSS_BACK
|
nullptr, // GGML_OP_CROSS_ENTROPY_LOSS_BACK
|
||||||
|
|
||||||
// ggml_unary_op
|
// ggml_unary_op
|
||||||
nullptr, // GGML_UNARY_OP_ABS
|
nullptr, // GGML_UNARY_OP_ABS
|
||||||
nullptr, // GGML_UNARY_OP_SGN
|
nullptr, // GGML_UNARY_OP_SGN
|
||||||
nullptr, // GGML_UNARY_OP_NEG
|
nullptr, // GGML_UNARY_OP_NEG
|
||||||
nullptr, // GGML_UNARY_OP_STEP
|
nullptr, // GGML_UNARY_OP_STEP
|
||||||
nullptr, // GGML_UNARY_OP_TANH
|
nullptr, // GGML_UNARY_OP_TANH
|
||||||
nullptr, // GGML_UNARY_OP_ELU
|
nullptr, // GGML_UNARY_OP_ELU
|
||||||
nullptr, // GGML_UNARY_OP_RELU
|
nullptr, // GGML_UNARY_OP_RELU
|
||||||
nullptr, // GGML_UNARY_OP_SIGMOID
|
nullptr, // GGML_UNARY_OP_SIGMOID
|
||||||
qnn_unary_op_impl<GGML_UNARY_OP_GELU + qnn::kGgmlUnaryOpStart>, // GGML_UNARY_OP_GELU
|
qnn_unary_op_impl<GGML_UNARY_OP_GELU + kGgmlUnaryOpStart>, // GGML_UNARY_OP_GELU
|
||||||
nullptr, // GGML_UNARY_OP_GELU_QUICK
|
nullptr, // GGML_UNARY_OP_GELU_QUICK
|
||||||
nullptr, // GGML_UNARY_OP_SILU
|
nullptr, // GGML_UNARY_OP_SILU
|
||||||
nullptr, // GGML_UNARY_OP_HARDSWISH
|
nullptr, // GGML_UNARY_OP_HARDSWISH
|
||||||
nullptr, // GGML_UNARY_OP_HARDSIGMOID
|
nullptr, // GGML_UNARY_OP_HARDSIGMOID
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(sizeof(kQnnOpsTable) / sizeof(kQnnOpsTable[0]) == (GGML_OP_COUNT + GGML_UNARY_OP_COUNT),
|
static_assert(sizeof(kQnnOpsTable) / sizeof(kQnnOpsTable[0]) == (GGML_OP_COUNT + GGML_UNARY_OP_COUNT),
|
||||||
|
@ -389,8 +396,8 @@ qnn::ggml_qnn_unary_op_array_t qnn::ggml_qnn_unary_op_array() {
|
||||||
return kQnnOpsTable;
|
return kQnnOpsTable;
|
||||||
}
|
}
|
||||||
|
|
||||||
qnn::ggml_qnn_binary_op_array_t qnn::ggml_qnn_binary_op_array() {
|
ggml_qnn_binary_op_array_t ggml_qnn_binary_op_array() {
|
||||||
static constexpr const qnn::ggml_qnn_binary_op_t kQnnOpsTable[] = {
|
static constexpr const ggml_qnn_binary_op_t kQnnOpsTable[] = {
|
||||||
nullptr, // GGML_OP_NONE
|
nullptr, // GGML_OP_NONE
|
||||||
nullptr, // GGML_OP_DUP
|
nullptr, // GGML_OP_DUP
|
||||||
qnn_binary_op_impl<GGML_OP_ADD>, // GGML_OP_ADD
|
qnn_binary_op_impl<GGML_OP_ADD>, // GGML_OP_ADD
|
||||||
|
@ -479,3 +486,67 @@ qnn::ggml_qnn_binary_op_array_t qnn::ggml_qnn_binary_op_array() {
|
||||||
"GGML_OP_COUNT does not match the size of the ops table");
|
"GGML_OP_COUNT does not match the size of the ops table");
|
||||||
return kQnnOpsTable;
|
return kQnnOpsTable;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace qnn {
|
||||||
|
|
||||||
|
bool ggml_qnn_supports_op(const ggml_tensor *op) {
|
||||||
|
if (op->op == GGML_OP_UNARY) {
|
||||||
|
if (!ggml_qnn_unary_op_array()[kGgmlUnaryOpStart + ggml_get_unary_op(op)]) {
|
||||||
|
QNN_LOG_DEBUG("unsupported unary op %d", ggml_get_unary_op(op));
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!op->src[0]) {
|
||||||
|
QNN_LOG_DEBUG("src0 is nullptr");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else if (op->op != GGML_OP_NONE) {
|
||||||
|
if (!ggml_qnn_unary_op_array()[op->op] && !ggml_qnn_binary_op_array()[op->op]) {
|
||||||
|
QNN_LOG_DEBUG("unsupported op %d", op->op);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!op->src[0] || !op->src[1]) {
|
||||||
|
QNN_LOG_DEBUG("src0 or src1 is nullptr");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (op->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
case GGML_TYPE_I8:
|
||||||
|
case GGML_TYPE_Q8_0:
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
QNN_LOG_DEBUG("unsupported src0 type %d", op->src[0]->type);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ggml_qnn_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor) {
|
||||||
|
size_t unary_op_idx = tensor->op;
|
||||||
|
if (tensor->op == GGML_OP_UNARY) {
|
||||||
|
unary_op_idx = kGgmlUnaryOpStart + ggml_get_unary_op(tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto unary_op = ggml_qnn_unary_op_array()[unary_op_idx];
|
||||||
|
if (unary_op) {
|
||||||
|
return unary_op(ctx, tensor->src[0], tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto binary_op = ggml_qnn_binary_op_array()[tensor->op];
|
||||||
|
if (binary_op) {
|
||||||
|
return binary_op(ctx, tensor->src[0], tensor->src[1], tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
QNN_LOG_WARN("unsupported op %d", tensor->op);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace qnn
|
||||||
|
|
|
@ -6,16 +6,7 @@
|
||||||
|
|
||||||
namespace qnn {
|
namespace qnn {
|
||||||
|
|
||||||
typedef bool (*ggml_qnn_unary_op_t)(ggml_backend_qnn_context *ctx, ggml_tensor *src, ggml_tensor *dst);
|
bool ggml_qnn_supports_op(const ggml_tensor *op);
|
||||||
typedef bool (*ggml_qnn_binary_op_t)(ggml_backend_qnn_context *ctx, ggml_tensor *src0, ggml_tensor *src1,
|
bool ggml_qnn_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor);
|
||||||
ggml_tensor *dst);
|
|
||||||
|
|
||||||
typedef const ggml_qnn_unary_op_t (&ggml_qnn_unary_op_array_t)[GGML_OP_COUNT + GGML_UNARY_OP_COUNT];
|
|
||||||
typedef const ggml_qnn_binary_op_t (&ggml_qnn_binary_op_array_t)[GGML_OP_COUNT];
|
|
||||||
|
|
||||||
constexpr const size_t kGgmlUnaryOpStart = GGML_OP_COUNT;
|
|
||||||
|
|
||||||
ggml_qnn_unary_op_array_t ggml_qnn_unary_op_array();
|
|
||||||
ggml_qnn_binary_op_array_t ggml_qnn_binary_op_array();
|
|
||||||
|
|
||||||
} // namespace qnn
|
} // namespace qnn
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue