use table to simpilify the op mapping

This commit is contained in:
hongruichen 2024-07-12 19:52:35 +08:00
parent f0894d897a
commit 0eb595cc6e

View file

@ -327,6 +327,41 @@ static void show_usage() {
);
}
typedef ggml_tensor * (*ggml_op_binary_t)(
ggml_context * ctx,
ggml_tensor * a,
ggml_tensor * b);
static constexpr const ggml_op_binary_t kBinaryOps[] = {
nullptr, // GGML_OP_NONE
nullptr, // GGML_OP_DUP
ggml_add, // GGML_OP_ADD
nullptr, // GGML_OP_ADD1
nullptr, // GGML_OP_ACC
nullptr, // GGML_OP_SUB
ggml_mul, // GGML_OP_MUL
nullptr, // GGML_OP_DIV
nullptr, // GGML_OP_SQR
nullptr, // GGML_OP_SQRT
nullptr, // GGML_OP_LOG
nullptr, // GGML_OP_SUM
nullptr, // GGML_OP_SUM_ROWS
nullptr, // GGML_OP_MEAN
nullptr, // GGML_OP_ARGMAX
nullptr, // GGML_OP_REPEAT
nullptr, // GGML_OP_REPEAT_BACK
nullptr, // GGML_OP_CONCAT
nullptr, // GGML_OP_SILU_BACK
nullptr, // GGML_OP_NORM
nullptr, // GGML_OP_RMS_NORM
nullptr, // GGML_OP_RMS_NORM_BACK
nullptr, // GGML_OP_GROUP_NORM
ggml_mul_mat, // GGML_OP_MUL_MAT
};
static_assert(kBinaryOps[GGML_OP_MUL_MAT] == ggml_mul_mat, "ggml_mul_mat at wrong index, check kBinaryOps");
static int qnn_op_ut(int num_threads, int n_backend_type, int n_ggml_op_type) {
int64_t n_begin_time = 0LL;
int64_t n_end_time = 0LL;
@ -398,19 +433,15 @@ static int qnn_op_ut(int num_threads, int n_backend_type, int n_ggml_op_type) {
ggml_set_input(src0);
ggml_set_input(src1);
switch (n_ggml_op_type) {
case GGML_OP_ADD:
dst = ggml_add(ctx, src0, src1);
break;
case GGML_OP_MUL_MAT:
dst = ggml_mul_mat(ctx, src0, src1);
break;
default:
QNN_LOG_WARN("ggml op %d(%s) not supported", n_ggml_op_type,
ggml_op_name((enum ggml_op) n_ggml_op_type));
ggml_free(ctx);
ggml_backend_free(backend);
return 3;
auto binary_op = kBinaryOps[n_ggml_op_type];
if (binary_op) {
dst = binary_op(ctx, src0, src1);
} else {
QNN_LOG_WARN("ggml op %d(%s) not supported", n_ggml_op_type,
ggml_op_name((enum ggml_op) n_ggml_op_type));
ggml_free(ctx);
ggml_backend_free(backend);
return 3;
}
ggml_set_output(dst);
@ -473,6 +504,11 @@ static int qnn_op_ut(int num_threads, int n_backend_type, int n_ggml_op_type) {
return 0;
}
static const std::unordered_map<std::string, int> kMapStringToGGMLOp = {
{"GGML_OP_ADD", GGML_OP_ADD},
{"GGML_OP_MUL_MAT", GGML_OP_MUL_MAT},
};
int main(int argc, char * argv[]) {
int num_threads = 4;
int n_backend_type = QNN_BACKEND_CPU;
@ -481,10 +517,9 @@ int main(int argc, char * argv[]) {
for (int i = 1; i < argc; i++) {
if (0 == strcmp(argv[i], "-t")) {
if (i + 1 < argc) {
if (0 == memcmp(argv[i + 1], "GGML_OP_ADD", 11)) {
n_ggml_op_type = GGML_OP_ADD;
} else if (0 == memcmp(argv[i + 1], "GGML_OP_MUL_MAT", 15)) {
n_ggml_op_type = GGML_OP_MUL_MAT;
auto it = kMapStringToGGMLOp.find(argv[i + 1]);
if (it != kMapStringToGGMLOp.end()) {
n_ggml_op_type = it->second;
} else {
show_usage();
return 1;