add helper function for binary op
This commit is contained in:
parent
b6f29273f0
commit
7ea28a6fac
1 changed files with 116 additions and 34 deletions
|
@ -143,6 +143,116 @@ qnn::ggml_qnn_graph_binary *get_qnn_graph_from_cache(ggml_backend_qnn_context *c
|
|||
return graph_ptr;
|
||||
}
|
||||
|
||||
constexpr const char *kGgmlOpToQnnOp[] = {
|
||||
nullptr, // GGML_OP_NONE
|
||||
nullptr, // GGML_OP_DUP
|
||||
QNN_OP_ELEMENT_WISE_ADD, // GGML_OP_ADD
|
||||
nullptr, // GGML_OP_ADD1
|
||||
nullptr, // GGML_OP_ACC
|
||||
nullptr, // GGML_OP_SUB
|
||||
nullptr, // GGML_OP_MUL
|
||||
nullptr, // GGML_OP_DIV
|
||||
nullptr, // GGML_OP_SQR
|
||||
nullptr, // GGML_OP_SQRT
|
||||
nullptr, // GGML_OP_LOG
|
||||
nullptr, // GGML_OP_SUM
|
||||
nullptr, // GGML_OP_SUM_ROWS
|
||||
nullptr, // GGML_OP_MEAN
|
||||
nullptr, // GGML_OP_ARGMAX
|
||||
nullptr, // GGML_OP_REPEAT
|
||||
nullptr, // GGML_OP_REPEAT_BACK
|
||||
nullptr, // GGML_OP_CONCAT
|
||||
nullptr, // GGML_OP_SILU_BACK
|
||||
nullptr, // GGML_OP_NORM
|
||||
nullptr, // GGML_OP_RMS_NORM
|
||||
nullptr, // GGML_OP_RMS_NORM_BACK
|
||||
nullptr, // GGML_OP_GROUP_NORM
|
||||
|
||||
QNN_OP_MAT_MUL, // GGML_OP_MUL_MAT
|
||||
nullptr, // GGML_OP_MUL_MAT_ID
|
||||
nullptr, // GGML_OP_OUT_PROD
|
||||
|
||||
nullptr, // GGML_OP_SCALE
|
||||
nullptr, // GGML_OP_SET
|
||||
nullptr, // GGML_OP_CPY
|
||||
nullptr, // GGML_OP_CONT
|
||||
nullptr, // GGML_OP_RESHAPE
|
||||
nullptr, // GGML_OP_VIEW
|
||||
nullptr, // GGML_OP_PERMUTE
|
||||
nullptr, // GGML_OP_TRANSPOSE
|
||||
nullptr, // GGML_OP_GET_ROWS
|
||||
nullptr, // GGML_OP_GET_ROWS_BACK
|
||||
nullptr, // GGML_OP_DIAG
|
||||
nullptr, // GGML_OP_DIAG_MASK_INF
|
||||
nullptr, // GGML_OP_DIAG_MASK_ZERO
|
||||
nullptr, // GGML_OP_SOFT_MAX
|
||||
nullptr, // GGML_OP_SOFT_MAX_BACK
|
||||
nullptr, // GGML_OP_ROPE
|
||||
nullptr, // GGML_OP_ROPE_BACK
|
||||
nullptr, // GGML_OP_CLAMP
|
||||
nullptr, // GGML_OP_CONV_TRANSPOSE_1D
|
||||
nullptr, // GGML_OP_IM2COL
|
||||
nullptr, // GGML_OP_CONV_TRANSPOSE_2D
|
||||
nullptr, // GGML_OP_POOL_1D
|
||||
nullptr, // GGML_OP_POOL_2D
|
||||
nullptr, // GGML_OP_UPSCALE
|
||||
nullptr, // GGML_OP_PAD
|
||||
nullptr, // GGML_OP_ARANGE
|
||||
nullptr, // GGML_OP_TIMESTEP_EMBEDDING
|
||||
nullptr, // GGML_OP_ARGSORT
|
||||
nullptr, // GGML_OP_LEAKY_RELU
|
||||
|
||||
nullptr, // GGML_OP_FLASH_ATTN_EXT
|
||||
nullptr, // GGML_OP_FLASH_ATTN_BACK
|
||||
nullptr, // GGML_OP_SSM_CONV
|
||||
nullptr, // GGML_OP_SSM_SCAN
|
||||
nullptr, // GGML_OP_WIN_PART
|
||||
nullptr, // GGML_OP_WIN_UNPART
|
||||
nullptr, // GGML_OP_GET_REL_POS
|
||||
nullptr, // GGML_OP_ADD_REL_POS
|
||||
|
||||
nullptr, // GGML_OP_UNARY
|
||||
|
||||
nullptr, // GGML_OP_MAP_UNARY
|
||||
nullptr, // GGML_OP_MAP_BINARY
|
||||
|
||||
nullptr, // GGML_OP_MAP_CUSTOM1_F32
|
||||
nullptr, // GGML_OP_MAP_CUSTOM2_F32
|
||||
nullptr, // GGML_OP_MAP_CUSTOM3_F32
|
||||
|
||||
nullptr, // GGML_OP_MAP_CUSTOM1
|
||||
nullptr, // GGML_OP_MAP_CUSTOM2
|
||||
nullptr, // GGML_OP_MAP_CUSTOM3
|
||||
|
||||
nullptr, // GGML_OP_CROSS_ENTROPY_LOSS
|
||||
nullptr, // GGML_OP_CROSS_ENTROPY_LOSS_BACK
|
||||
};
|
||||
|
||||
static_assert(sizeof(kGgmlOpToQnnOp) / sizeof(kGgmlOpToQnnOp[0]) == GGML_OP_COUNT,
|
||||
"GGML_OP_COUNT does not match the size of the ops table");
|
||||
|
||||
template <ggml_op _GgmlOp>
|
||||
void qnn_binary_op_impl(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst) {
|
||||
static_assert(kGgmlOpToQnnOp[_GgmlOp] != nullptr, "GGML_OP does not have a corresponding QNN_OP");
|
||||
|
||||
qnn::qnn_perf perf(ggml_op_name(_GgmlOp));
|
||||
perf.start();
|
||||
|
||||
bool succeed = false;
|
||||
qnn::ggml_qnn_graph_binary *graph_ptr =
|
||||
get_qnn_graph_from_cache<2, 1>(ctx, _GgmlOp, kGgmlOpToQnnOp[_GgmlOp], { src0, src1 }, { dst });
|
||||
if (graph_ptr) {
|
||||
succeed = execute_graph<2, 1>(graph_ptr, { src0, src1 }, { dst });
|
||||
}
|
||||
|
||||
if (!succeed) {
|
||||
print_ggml_tensor(src0);
|
||||
print_ggml_tensor(src1);
|
||||
print_ggml_tensor(dst);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#ifndef NDEBUG
|
||||
|
@ -160,22 +270,7 @@ qnn::ggml_qnn_graph_binary *get_qnn_graph_from_cache(ggml_backend_qnn_context *c
|
|||
static void ggml_qnn_add(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst) {
|
||||
CHECK_PARAMS(ctx, src0, src1, dst);
|
||||
|
||||
qnn::qnn_perf perf("ggml_op_qnn_add");
|
||||
perf.start();
|
||||
|
||||
bool succeed = false;
|
||||
qnn::ggml_qnn_graph_binary *graph_ptr =
|
||||
get_qnn_graph_from_cache<2, 1>(ctx, GGML_OP_ADD, QNN_OP_ELEMENT_WISE_ADD, { src0, src1 }, { dst });
|
||||
if (graph_ptr) {
|
||||
succeed = execute_graph<2, 1>(graph_ptr, { src0, src1 }, { dst });
|
||||
}
|
||||
|
||||
if (!succeed) {
|
||||
print_ggml_tensor(src0);
|
||||
print_ggml_tensor(src1);
|
||||
print_ggml_tensor(dst);
|
||||
}
|
||||
qnn_binary_op_impl<GGML_OP_ADD>(ctx, src0, src1, dst);
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -192,22 +287,7 @@ static void ggml_qnn_add(ggml_backend_qnn_context *ctx, const ggml_tensor *src0,
|
|||
static void ggml_qnn_mul_mat(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst) {
|
||||
CHECK_PARAMS(ctx, src0, src1, dst);
|
||||
|
||||
qnn::qnn_perf perf("ggml_op_qnn_mul_mat");
|
||||
perf.start();
|
||||
|
||||
bool succeed = false;
|
||||
qnn::ggml_qnn_graph_binary *graph_ptr =
|
||||
get_qnn_graph_from_cache<2, 1>(ctx, GGML_OP_MUL_MAT, QNN_OP_MAT_MUL, { src0, src1 }, { dst });
|
||||
if (graph_ptr) {
|
||||
succeed = execute_graph<2, 1>(graph_ptr, { src0, src1 }, { dst });
|
||||
}
|
||||
|
||||
if (!succeed) {
|
||||
print_ggml_tensor(src0);
|
||||
print_ggml_tensor(src1);
|
||||
print_ggml_tensor(dst);
|
||||
}
|
||||
qnn_binary_op_impl<GGML_OP_MUL_MAT>(ctx, src0, src1, dst);
|
||||
}
|
||||
|
||||
static void ggml_qnn_repeat(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
|
@ -320,7 +400,7 @@ static void ggml_qnn_nop(ggml_backend_qnn_context *ctx, const ggml_tensor *src0,
|
|||
}
|
||||
|
||||
qnn::ggml_qnn_op_array_t qnn::ggml_qnn_op_array() {
|
||||
static constexpr const qnn::ggml_qnn_op_t kQnnOpsTable[GGML_OP_COUNT] = {
|
||||
static constexpr const qnn::ggml_qnn_op_t kQnnOpsTable[] = {
|
||||
nullptr, // GGML_OP_NONE
|
||||
nullptr, // GGML_OP_DUP
|
||||
ggml_qnn_add, // GGML_OP_ADD
|
||||
|
@ -405,5 +485,7 @@ qnn::ggml_qnn_op_array_t qnn::ggml_qnn_op_array() {
|
|||
nullptr, // GGML_OP_CROSS_ENTROPY_LOSS_BACK
|
||||
};
|
||||
|
||||
static_assert(sizeof(kQnnOpsTable) / sizeof(kQnnOpsTable[0]) == GGML_OP_COUNT,
|
||||
"GGML_OP_COUNT does not match the size of the ops table");
|
||||
return kQnnOpsTable;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue