feat: add GGML_UNARY_OP_GELU
This commit is contained in:
parent
ce199b2de7
commit
d82b3a0bdb
3 changed files with 60 additions and 25 deletions
|
@ -135,14 +135,19 @@ struct ggml_backend_qnn_buffer_type_context {
|
||||||
//
|
//
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
static bool ggml_qnn_can_handle_op(ggml_backend_qnn_context *ctx, const struct ggml_tensor *tensor) {
|
static bool ggml_qnn_can_handle_op(ggml_backend_qnn_context *ctx, const struct ggml_tensor *tensor) {
|
||||||
if (ggml_is_empty(tensor) ||
|
if (ggml_is_empty(tensor)) {
|
||||||
(!qnn::ggml_qnn_unary_op_array()[tensor->op] && !qnn::ggml_qnn_binary_op_array()[tensor->op])) {
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!qnn::ggml_qnn_unary_op_array()[tensor->op] && !qnn::ggml_qnn_binary_op_array()[tensor->op] &&
|
||||||
|
(tensor->op != GGML_OP_UNARY ||
|
||||||
|
qnn::ggml_qnn_unary_op_array()[qnn::kGgmlUnaryOpStart + ggml_get_unary_op(tensor)])) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const struct ggml_tensor *src0 = tensor->src[0];
|
const struct ggml_tensor *src0 = tensor->src[0];
|
||||||
const struct ggml_tensor *src1 = tensor->src[1];
|
const struct ggml_tensor *src1 = tensor->src[1];
|
||||||
if (nullptr == src0 || nullptr == src1) {
|
if (!src0 || !src1) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -162,18 +167,16 @@ static bool ggml_qnn_can_handle_op(ggml_backend_qnn_context *ctx, const struct g
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tensor->op == GGML_OP_MUL_MAT) {
|
|
||||||
if (ne00 <= 32 || ne01 <= 32 || ne10 <= 32 || ne11 <= 32) {
|
|
||||||
// comment it for make UT of mul_mat with QNN RPC happy
|
|
||||||
// return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_qnn_compute_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor) {
|
static bool ggml_qnn_compute_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor) {
|
||||||
auto unary_op = qnn::ggml_qnn_unary_op_array()[tensor->op];
|
size_t unary_op_idx = tensor->op;
|
||||||
|
if (tensor->op == GGML_OP_UNARY) {
|
||||||
|
unary_op_idx = qnn::kGgmlUnaryOpStart + ggml_get_unary_op(tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto unary_op = qnn::ggml_qnn_unary_op_array()[unary_op_idx];
|
||||||
if (unary_op) {
|
if (unary_op) {
|
||||||
return unary_op(ctx, tensor->src[0], tensor);
|
return unary_op(ctx, tensor->src[0], tensor);
|
||||||
}
|
}
|
||||||
|
|
|
@ -158,12 +158,16 @@ qnn::ggml_qnn_binary_graph_cache_t &get_qnn_graph_cache(ggml_backend_qnn_context
|
||||||
|
|
||||||
template <size_t _InputSize, size_t _OutputSize>
|
template <size_t _InputSize, size_t _OutputSize>
|
||||||
qnn::ggml_qnn_graph<_InputSize, _OutputSize> *get_qnn_graph_from_cache(
|
qnn::ggml_qnn_graph<_InputSize, _OutputSize> *get_qnn_graph_from_cache(
|
||||||
ggml_backend_qnn_context *ctx, ggml_op op, const std::string &qnn_op,
|
ggml_backend_qnn_context *ctx, size_t op, const std::string &qnn_op,
|
||||||
const std::array<const ggml_tensor *, _InputSize> &inputs, const std::array<ggml_tensor *, _OutputSize> &outputs) {
|
const std::array<const ggml_tensor *, _InputSize> &inputs, const std::array<ggml_tensor *, _OutputSize> &outputs) {
|
||||||
using graph_t = qnn::ggml_qnn_graph<_InputSize, _OutputSize>;
|
using graph_t = qnn::ggml_qnn_graph<_InputSize, _OutputSize>;
|
||||||
|
|
||||||
|
GGML_ASSERT(op < (GGML_OP_COUNT + GGML_UNARY_OP_COUNT));
|
||||||
|
|
||||||
auto &graph_cache = get_qnn_graph_cache(ctx, inputs, outputs);
|
auto &graph_cache = get_qnn_graph_cache(ctx, inputs, outputs);
|
||||||
const std::string graph_key(ggml_op_name(op));
|
const auto *op_name = op < qnn::kGgmlUnaryOpStart ? ggml_op_name(ggml_op(op))
|
||||||
|
: ggml_unary_op_name(ggml_unary_op(op - qnn::kGgmlUnaryOpStart));
|
||||||
|
const std::string graph_key(op_name);
|
||||||
auto it = graph_cache.find(graph_key);
|
auto it = graph_cache.find(graph_key);
|
||||||
graph_t *graph_ptr = nullptr;
|
graph_t *graph_ptr = nullptr;
|
||||||
if (it != graph_cache.end()) {
|
if (it != graph_cache.end()) {
|
||||||
|
@ -276,10 +280,27 @@ constexpr const char *kGgmlOpToQnnOp[] = {
|
||||||
|
|
||||||
nullptr, // GGML_OP_CROSS_ENTROPY_LOSS
|
nullptr, // GGML_OP_CROSS_ENTROPY_LOSS
|
||||||
nullptr, // GGML_OP_CROSS_ENTROPY_LOSS_BACK
|
nullptr, // GGML_OP_CROSS_ENTROPY_LOSS_BACK
|
||||||
|
|
||||||
|
// ggml_unary_op
|
||||||
|
nullptr, // GGML_UNARY_OP_ABS
|
||||||
|
nullptr, // GGML_UNARY_OP_SGN
|
||||||
|
nullptr, // GGML_UNARY_OP_NEG
|
||||||
|
nullptr, // GGML_UNARY_OP_STEP
|
||||||
|
nullptr, // GGML_UNARY_OP_TANH
|
||||||
|
nullptr, // GGML_UNARY_OP_ELU
|
||||||
|
nullptr, // GGML_UNARY_OP_RELU
|
||||||
|
nullptr, // GGML_UNARY_OP_SIGMOID
|
||||||
|
QNN_OP_GELU, // GGML_UNARY_OP_GELU
|
||||||
|
nullptr, // GGML_UNARY_OP_GELU_QUICK
|
||||||
|
nullptr, // GGML_UNARY_OP_SILU
|
||||||
|
nullptr, // GGML_UNARY_OP_HARDSWISH
|
||||||
|
nullptr, // GGML_UNARY_OP_HARDSIGMOID
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(sizeof(kGgmlOpToQnnOp) / sizeof(kGgmlOpToQnnOp[0]) == GGML_OP_COUNT,
|
static_assert(sizeof(kGgmlOpToQnnOp) / sizeof(kGgmlOpToQnnOp[0]) == (GGML_OP_COUNT + GGML_UNARY_OP_COUNT),
|
||||||
"GGML_OP_COUNT does not match the size of the ops table");
|
"GGML_OP_COUNT does not match the size of the kGgmlOpToQnnOp table");
|
||||||
|
static_assert(kGgmlOpToQnnOp[GGML_UNARY_OP_GELU + qnn::kGgmlUnaryOpStart] != nullptr,
|
||||||
|
"GGML_UNARY_OP_GELU does not correspond to QNN_OP_GELU");
|
||||||
|
|
||||||
template <ggml_op _GgmlOp>
|
template <ggml_op _GgmlOp>
|
||||||
bool qnn_binary_op_impl(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
bool qnn_binary_op_impl(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||||
|
@ -288,9 +309,6 @@ bool qnn_binary_op_impl(ggml_backend_qnn_context *ctx, const ggml_tensor *src0,
|
||||||
|
|
||||||
CHECK_PARAMS(ctx, src0, src1, dst);
|
CHECK_PARAMS(ctx, src0, src1, dst);
|
||||||
|
|
||||||
qnn::qnn_perf perf(ggml_op_name(_GgmlOp));
|
|
||||||
perf.start();
|
|
||||||
|
|
||||||
bool succeed = false;
|
bool succeed = false;
|
||||||
qnn::ggml_qnn_graph_binary *graph_ptr =
|
qnn::ggml_qnn_graph_binary *graph_ptr =
|
||||||
get_qnn_graph_from_cache<2, 1>(ctx, _GgmlOp, kGgmlOpToQnnOp[_GgmlOp], { src0, src1 }, { dst });
|
get_qnn_graph_from_cache<2, 1>(ctx, _GgmlOp, kGgmlOpToQnnOp[_GgmlOp], { src0, src1 }, { dst });
|
||||||
|
@ -307,15 +325,12 @@ bool qnn_binary_op_impl(ggml_backend_qnn_context *ctx, const ggml_tensor *src0,
|
||||||
return succeed;
|
return succeed;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <ggml_op _GgmlOp>
|
template <size_t _GgmlOp>
|
||||||
bool qnn_unary_op_impl(ggml_backend_qnn_context *ctx, const ggml_tensor *src, ggml_tensor *dst) {
|
bool qnn_unary_op_impl(ggml_backend_qnn_context *ctx, const ggml_tensor *src, ggml_tensor *dst) {
|
||||||
static_assert(kGgmlOpToQnnOp[_GgmlOp] != nullptr, "GGML_OP does not have a corresponding QNN_OP");
|
static_assert(kGgmlOpToQnnOp[_GgmlOp] != nullptr, "GGML_OP does not have a corresponding QNN_OP");
|
||||||
|
|
||||||
CHECK_PARAMS(ctx, src, dst);
|
CHECK_PARAMS(ctx, src, dst);
|
||||||
|
|
||||||
qnn::qnn_perf perf(ggml_op_name(_GgmlOp));
|
|
||||||
perf.start();
|
|
||||||
|
|
||||||
bool succeed = false;
|
bool succeed = false;
|
||||||
auto *graph_ptr = get_qnn_graph_from_cache<1, 1>(ctx, _GgmlOp, kGgmlOpToQnnOp[_GgmlOp], { src }, { dst });
|
auto *graph_ptr = get_qnn_graph_from_cache<1, 1>(ctx, _GgmlOp, kGgmlOpToQnnOp[_GgmlOp], { src }, { dst });
|
||||||
if (graph_ptr) {
|
if (graph_ptr) {
|
||||||
|
@ -416,10 +431,25 @@ qnn::ggml_qnn_unary_op_array_t qnn::ggml_qnn_unary_op_array() {
|
||||||
|
|
||||||
nullptr, // GGML_OP_CROSS_ENTROPY_LOSS
|
nullptr, // GGML_OP_CROSS_ENTROPY_LOSS
|
||||||
nullptr, // GGML_OP_CROSS_ENTROPY_LOSS_BACK
|
nullptr, // GGML_OP_CROSS_ENTROPY_LOSS_BACK
|
||||||
|
|
||||||
|
// ggml_unary_op
|
||||||
|
nullptr, // GGML_UNARY_OP_ABS
|
||||||
|
nullptr, // GGML_UNARY_OP_SGN
|
||||||
|
nullptr, // GGML_UNARY_OP_NEG
|
||||||
|
nullptr, // GGML_UNARY_OP_STEP
|
||||||
|
nullptr, // GGML_UNARY_OP_TANH
|
||||||
|
nullptr, // GGML_UNARY_OP_ELU
|
||||||
|
nullptr, // GGML_UNARY_OP_RELU
|
||||||
|
nullptr, // GGML_UNARY_OP_SIGMOID
|
||||||
|
qnn_unary_op_impl<GGML_UNARY_OP_GELU + qnn::kGgmlUnaryOpStart>, // GGML_UNARY_OP_GELU
|
||||||
|
nullptr, // GGML_UNARY_OP_GELU_QUICK
|
||||||
|
nullptr, // GGML_UNARY_OP_SILU
|
||||||
|
nullptr, // GGML_UNARY_OP_HARDSWISH
|
||||||
|
nullptr, // GGML_UNARY_OP_HARDSIGMOID
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(sizeof(kQnnOpsTable) / sizeof(kQnnOpsTable[0]) == GGML_OP_COUNT,
|
static_assert(sizeof(kQnnOpsTable) / sizeof(kQnnOpsTable[0]) == (GGML_OP_COUNT + GGML_UNARY_OP_COUNT),
|
||||||
"GGML_OP_COUNT does not match the size of the ops table");
|
"GGML_OP_COUNT does not match the size of the kQnnOpsTable table");
|
||||||
return kQnnOpsTable;
|
return kQnnOpsTable;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,9 +10,11 @@ typedef bool (*ggml_qnn_unary_op_t)(ggml_backend_qnn_context *ctx, const ggml_te
|
||||||
typedef bool (*ggml_qnn_binary_op_t)(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
typedef bool (*ggml_qnn_binary_op_t)(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||||
ggml_tensor *dst);
|
ggml_tensor *dst);
|
||||||
|
|
||||||
typedef const ggml_qnn_unary_op_t (&ggml_qnn_unary_op_array_t)[GGML_OP_COUNT];
|
typedef const ggml_qnn_unary_op_t (&ggml_qnn_unary_op_array_t)[GGML_OP_COUNT + GGML_UNARY_OP_COUNT];
|
||||||
typedef const ggml_qnn_binary_op_t (&ggml_qnn_binary_op_array_t)[GGML_OP_COUNT];
|
typedef const ggml_qnn_binary_op_t (&ggml_qnn_binary_op_array_t)[GGML_OP_COUNT];
|
||||||
|
|
||||||
|
constexpr const size_t kGgmlUnaryOpStart = GGML_OP_COUNT;
|
||||||
|
|
||||||
ggml_qnn_unary_op_array_t ggml_qnn_unary_op_array();
|
ggml_qnn_unary_op_array_t ggml_qnn_unary_op_array();
|
||||||
ggml_qnn_binary_op_array_t ggml_qnn_binary_op_array();
|
ggml_qnn_binary_op_array_t ggml_qnn_binary_op_array();
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue