expose op at unit test

This commit is contained in:
hongruichen 2024-07-13 10:55:36 +08:00
parent 100ccd5e7f
commit c1e2283887

View file

@ -327,21 +327,51 @@ static void show_usage() {
); );
} }
typedef ggml_tensor * (*ggml_op_unary_t)(
ggml_context * ctx,
ggml_tensor * a);
typedef ggml_tensor * (*ggml_op_binary_t)( typedef ggml_tensor * (*ggml_op_binary_t)(
ggml_context * ctx, ggml_context * ctx,
ggml_tensor * a, ggml_tensor * a,
ggml_tensor * b); ggml_tensor * b);
static constexpr const ggml_op_unary_t kUnaryOps[] = {
nullptr, // GGML_OP_NONE
nullptr, // GGML_OP_DUP
nullptr, // GGML_OP_ADD
nullptr, // GGML_OP_ADD1
nullptr, // GGML_OP_ACC
nullptr, // GGML_OP_SUB
nullptr, // GGML_OP_MUL
nullptr, // GGML_OP_DIV
nullptr, // GGML_OP_SQR
ggml_sqrt, // GGML_OP_SQRT
nullptr, // GGML_OP_LOG
nullptr, // GGML_OP_SUM
nullptr, // GGML_OP_SUM_ROWS
nullptr, // GGML_OP_MEAN
nullptr, // GGML_OP_ARGMAX
nullptr, // GGML_OP_REPEAT
nullptr, // GGML_OP_REPEAT_BACK
nullptr, // GGML_OP_CONCAT
nullptr, // GGML_OP_SILU_BACK
nullptr, // GGML_OP_NORM
nullptr, // GGML_OP_RMS_NORM
nullptr, // GGML_OP_RMS_NORM_BACK
nullptr, // GGML_OP_GROUP_NORM
nullptr, // GGML_OP_MUL_MAT
};
static constexpr const ggml_op_binary_t kBinaryOps[] = { static constexpr const ggml_op_binary_t kBinaryOps[] = {
nullptr, // GGML_OP_NONE nullptr, // GGML_OP_NONE
nullptr, // GGML_OP_DUP nullptr, // GGML_OP_DUP
ggml_add, // GGML_OP_ADD ggml_add, // GGML_OP_ADD
nullptr, // GGML_OP_ADD1 nullptr, // GGML_OP_ADD1
nullptr, // GGML_OP_ACC nullptr, // GGML_OP_ACC
nullptr, // GGML_OP_SUB ggml_sub, // GGML_OP_SUB
ggml_mul, // GGML_OP_MUL ggml_mul, // GGML_OP_MUL
nullptr, // GGML_OP_DIV ggml_div, // GGML_OP_DIV
nullptr, // GGML_OP_SQR nullptr, // GGML_OP_SQR
nullptr, // GGML_OP_SQRT nullptr, // GGML_OP_SQRT
nullptr, // GGML_OP_LOG nullptr, // GGML_OP_LOG
@ -433,8 +463,11 @@ static int qnn_op_ut(int num_threads, int n_backend_type, int n_ggml_op_type) {
ggml_set_input(src0); ggml_set_input(src0);
ggml_set_input(src1); ggml_set_input(src1);
auto unary_op = kUnaryOps[n_ggml_op_type];
auto binary_op = kBinaryOps[n_ggml_op_type]; auto binary_op = kBinaryOps[n_ggml_op_type];
if (binary_op) { if (unary_op) {
dst = unary_op(ctx, src0);
} else if (binary_op) {
dst = binary_op(ctx, src0, src1); dst = binary_op(ctx, src0, src1);
} else { } else {
QNN_LOG_WARN("ggml op %d(%s) not supported", n_ggml_op_type, QNN_LOG_WARN("ggml op %d(%s) not supported", n_ggml_op_type,
@ -504,10 +537,15 @@ static int qnn_op_ut(int num_threads, int n_backend_type, int n_ggml_op_type) {
return 0; return 0;
} }
#define DEFINE_OP(op) { #op, op }
static const std::unordered_map<std::string, int> kMapStringToGGMLOp = { static const std::unordered_map<std::string, int> kMapStringToGGMLOp = {
{"GGML_OP_ADD", GGML_OP_ADD}, DEFINE_OP(GGML_OP_ADD),
{"GGML_OP_MUL_MAT", GGML_OP_MUL_MAT}, DEFINE_OP(GGML_OP_SUB),
{"GGML_OP_MUL", GGML_OP_MUL}, DEFINE_OP(GGML_OP_MUL),
DEFINE_OP(GGML_OP_DIV),
DEFINE_OP(GGML_OP_SQRT),
DEFINE_OP(GGML_OP_MUL_MAT),
}; };
int main(int argc, char * argv[]) { int main(int argc, char * argv[]) {