From e1056da1c083ecba7b10d4833963a7c429cee054 Mon Sep 17 00:00:00 2001
From: hongruichen
Date: Mon, 24 Jun 2024 12:06:42 +0800
Subject: [PATCH] fix op handle checker

---
 ggml-qnn.cpp | 231 ++++++++++++++++++++++-----------------------
 1 file changed, 100 insertions(+), 131 deletions(-)

diff --git a/ggml-qnn.cpp b/ggml-qnn.cpp
index 3a667a197..ffa437184 100644
--- a/ggml-qnn.cpp
+++ b/ggml-qnn.cpp
@@ -354,12 +354,100 @@ static int free_qnn_tensor(Qnn_Tensor_t & tensor) {
 // implementation of QNN backend for GGML
 //
 // =================================================================================================
+static void ggml_qnn_add(ggml_backend_qnn_context * ctx, const ggml_tensor * src0,
+                         const ggml_tensor * src1, ggml_tensor * dst);
+static void ggml_qnn_mul_mat(ggml_backend_qnn_context * ctx,
+                             const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+static ggml_qnn_func_t s_op_table[GGML_OP_COUNT] = {
+    nullptr,          // GGML_OP_NONE
+    nullptr,          // GGML_OP_DUP
+    ggml_qnn_add,     // GGML_OP_ADD
+    nullptr,          // GGML_OP_ADD1
+    nullptr,          // GGML_OP_ACC
+    nullptr,          // GGML_OP_SUB
+    nullptr,          // GGML_OP_MUL
+    nullptr,          // GGML_OP_DIV
+    nullptr,          // GGML_OP_SQR
+    nullptr,          // GGML_OP_SQRT
+    nullptr,          // GGML_OP_LOG
+    nullptr,          // GGML_OP_SUM
+    nullptr,          // GGML_OP_SUM_ROWS
+    nullptr,          // GGML_OP_MEAN
+    nullptr,          // GGML_OP_ARGMAX
+    nullptr,          // GGML_OP_REPEAT
+    nullptr,          // GGML_OP_REPEAT_BACK
+    nullptr,          // GGML_OP_CONCAT
+    nullptr,          // GGML_OP_SILU_BACK
+    nullptr,          // GGML_OP_NORM
+    nullptr,          // GGML_OP_RMS_NORM
+    nullptr,          // GGML_OP_RMS_NORM_BACK
+    nullptr,          // GGML_OP_GROUP_NORM
+
+    ggml_qnn_mul_mat, // GGML_OP_MUL_MAT
+    nullptr,          // GGML_OP_MUL_MAT_ID
+    nullptr,          // GGML_OP_OUT_PROD
+
+    nullptr,          // GGML_OP_SCALE
+    nullptr,          // GGML_OP_SET
+    nullptr,          // GGML_OP_CPY
+    nullptr,          // GGML_OP_CONT
+    nullptr,          // GGML_OP_RESHAPE
+    nullptr,          // GGML_OP_VIEW
+    nullptr,          // GGML_OP_PERMUTE
+    nullptr,          // GGML_OP_TRANSPOSE
+    nullptr,          // GGML_OP_GET_ROWS
+    nullptr,          // GGML_OP_GET_ROWS_BACK
+    nullptr,          // GGML_OP_DIAG
+    nullptr,          // GGML_OP_DIAG_MASK_INF
+    nullptr,          // GGML_OP_DIAG_MASK_ZERO
+    nullptr,          // GGML_OP_SOFT_MAX
+    nullptr,          // GGML_OP_SOFT_MAX_BACK
+    nullptr,          // GGML_OP_ROPE
+    nullptr,          // GGML_OP_ROPE_BACK
+    nullptr,          // GGML_OP_CLAMP
+    nullptr,          // GGML_OP_CONV_TRANSPOSE_1D
+    nullptr,          // GGML_OP_IM2COL
+    nullptr,          // GGML_OP_CONV_TRANSPOSE_2D
+    nullptr,          // GGML_OP_POOL_1D
+    nullptr,          // GGML_OP_POOL_2D
+    nullptr,          // GGML_OP_UPSCALE
+    nullptr,          // GGML_OP_PAD
+    nullptr,          // GGML_OP_ARANGE
+    nullptr,          // GGML_OP_TIMESTEP_EMBEDDING
+    nullptr,          // GGML_OP_ARGSORT
+    nullptr,          // GGML_OP_LEAKY_RELU
+
+    nullptr,          // GGML_OP_FLASH_ATTN_EXT
+    nullptr,          // GGML_OP_FLASH_ATTN_BACK
+    nullptr,          // GGML_OP_SSM_CONV
+    nullptr,          // GGML_OP_SSM_SCAN
+    nullptr,          // GGML_OP_WIN_PART
+    nullptr,          // GGML_OP_WIN_UNPART
+    nullptr,          // GGML_OP_GET_REL_POS
+    nullptr,          // GGML_OP_ADD_REL_POS
+
+    nullptr,          // GGML_OP_UNARY
+
+    nullptr,          // GGML_OP_MAP_UNARY
+    nullptr,          // GGML_OP_MAP_BINARY
+
+    nullptr,          // GGML_OP_MAP_CUSTOM1_F32
+    nullptr,          // GGML_OP_MAP_CUSTOM2_F32
+    nullptr,          // GGML_OP_MAP_CUSTOM3_F32
+
+    nullptr,          // GGML_OP_MAP_CUSTOM1
+    nullptr,          // GGML_OP_MAP_CUSTOM2
+    nullptr,          // GGML_OP_MAP_CUSTOM3
+
+    nullptr,          // GGML_OP_CROSS_ENTROPY_LOSS
+    nullptr,          // GGML_OP_CROSS_ENTROPY_LOSS_BACK
+};
+
 static bool ggml_qnn_can_handle_op(ggml_backend_qnn_context * ctx,
                                    const struct ggml_tensor * tensor,
                                    bool b_dump_tensor_info) {
-    if (ggml_is_empty(tensor) || tensor->op == GGML_OP_RESHAPE ||
-        tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_VIEW ||
-        tensor->op == GGML_OP_PERMUTE || tensor->op == GGML_OP_NONE) {
+    if (ggml_is_empty(tensor) || !s_op_table[tensor->op]) {
         return false;
     }
@@ -369,10 +457,10 @@ static bool ggml_qnn_can_handle_op(ggml_backend_qnn_context * ctx,
         return false;
     }
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
+    const auto ne00 = src0->ne[0];
+    const auto ne01 = src0->ne[1];
+    const auto ne10 = src1->ne[0];
+    const auto ne11 = src1->ne[1];
     // make qnn_get_ggml_tensor_rank and QNN SDK happy
     if (ne00 <= 1 || ne01 <= 1 || ne10 <= 1 || ne11 <= 1) {
         return false;
     }
@@ -951,132 +1039,13 @@ static void ggml_qnn_nop(ggml_backend_qnn_context * ctx, const ggml_tensor * src
 bool ggml_qnn_compute_forward(ggml_backend_qnn_context * ctx,
                               struct ggml_compute_params * params,
                               struct ggml_tensor * tensor) {
-    ggml_qnn_func_t func = nullptr;
-
-    switch (tensor->op) {
-    case GGML_OP_ADD:
-        func = ggml_qnn_add;
-        break;
-    case GGML_OP_MUL_MAT:
-        func = ggml_qnn_mul_mat;
-        break;
-    case GGML_OP_REPEAT:
-        func = ggml_qnn_repeat;
-        break;
-    case GGML_OP_GET_ROWS:
-        func = ggml_qnn_get_rows;
-        break;
-    case GGML_OP_DUP:
-        func = ggml_qnn_dup;
-        break;
-    case GGML_OP_ACC:
-        func = ggml_qnn_acc;
-        break;
-    case GGML_OP_DIV:
-        func = ggml_qnn_div;
-        break;
-    case GGML_OP_UNARY:
-        switch (ggml_get_unary_op(tensor)) {
-        case GGML_UNARY_OP_GELU:
-            func = ggml_qnn_gelu;
-            break;
-        case GGML_UNARY_OP_SILU:
-            func = ggml_qnn_silu;
-            break;
-        case GGML_UNARY_OP_GELU_QUICK:
-            func = ggml_qnn_gelu_quick;
-            break;
-        case GGML_UNARY_OP_TANH:
-            func = ggml_qnn_tanh;
-            break;
-        case GGML_UNARY_OP_RELU:
-            func = ggml_qnn_relu;
-            break;
-        case GGML_UNARY_OP_HARDSIGMOID:
-            func = ggml_qnn_hardsigmoid;
-            break;
-        case GGML_UNARY_OP_HARDSWISH:
-            func = ggml_qnn_hardswish;
-            break;
-        default:
-            return false;
-        }
-        break;
-    case GGML_OP_NORM:
-        func = ggml_qnn_norm;
-        break;
-    case GGML_OP_GROUP_NORM:
-        func = ggml_qnn_group_norm;
-        break;
-    case GGML_OP_CONCAT:
-        func = ggml_qnn_concat;
-        break;
-    case GGML_OP_UPSCALE:
-        func = ggml_qnn_upscale;
-        break;
-    case GGML_OP_PAD:
-        func = ggml_qnn_pad;
-        break;
-    case GGML_OP_LEAKY_RELU:
-        func = ggml_qnn_leaky_relu;
-        break;
-    case GGML_OP_RMS_NORM:
-        func = ggml_qnn_rms_norm;
-        break;
-    case GGML_OP_MUL_MAT_ID:
-        func = ggml_qnn_mul_mat_id;
-        break;
-    case GGML_OP_SCALE:
-        func = ggml_qnn_scale;
-        break;
-    case GGML_OP_SQR:
-        func = ggml_qnn_sqr;
-        break;
-    case GGML_OP_CLAMP:
-        func = ggml_qnn_clamp;
-        break;
-    case GGML_OP_CPY:
-        func = ggml_qnn_cpy;
-        break;
-    case GGML_OP_CONT:
-        func = ggml_qnn_dup;
-        break;
-    case GGML_OP_NONE:
-    case GGML_OP_RESHAPE:
-    case GGML_OP_VIEW:
-    case GGML_OP_PERMUTE:
-    case GGML_OP_TRANSPOSE:
-        func = ggml_qnn_nop;
-        break;
-    case GGML_OP_DIAG_MASK_INF:
-        func = ggml_qnn_diag_mask_inf;
-        break;
-    case GGML_OP_SOFT_MAX:
-        func = ggml_qnn_soft_max;
-        break;
-    case GGML_OP_ROPE:
-        func = ggml_qnn_rope;
-        break;
-    case GGML_OP_IM2COL:
-        func = ggml_qnn_im2col;
-        break;
-    case GGML_OP_POOL_2D:
-        func = ggml_qnn_pool2d;
-        break;
-    case GGML_OP_SUM_ROWS:
-        func = ggml_qnn_sum_rows;
-        break;
-    case GGML_OP_ARGSORT:
-        func = ggml_qnn_argsort;
-        break;
-    default:
+    ggml_qnn_func_t func = s_op_table[tensor->op];
+    if (!func) {
+        QNN_LOG_WARN("unsupported op %d", tensor->op);
         return false;
     }
-    if (nullptr != func) {
-        func(ctx, tensor->src[0], tensor->src[1], tensor);
-    }
-
+    func(ctx, tensor->src[0], tensor->src[1], tensor);
     return true;
 }
@@ -1349,7 +1318,7 @@ GGML_CALL static bool ggml_backend_qnn_supports_op(ggml_backend_t backend,
 GGML_CALL static bool ggml_backend_qnn_offload_op(ggml_backend_t backend,const ggml_tensor * tensor) {
     ggml_backend_qnn_context * ctx = (ggml_backend_qnn_context *) backend->context;
-    return ggml_qnn_compute_forward(ctx, nullptr, (ggml_tensor *) tensor);
+    return ggml_qnn_can_handle_op(ctx, tensor, false);
 }
 
 static ggml_backend_i ggml_backend_qnn_interface = {
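
Illustrative sketch (not part of the patch above): the change replaces the per-op switch in ggml_qnn_compute_forward with a single function-pointer table, s_op_table, indexed by tensor->op, and both the capability check (ggml_qnn_can_handle_op) and the offload check now consult that same table. The standalone program below shows the same dispatch-table pattern in miniature; every name in it (sample_op, sample_func_t, op_table, handle_add, can_handle, forward) is hypothetical and does not appear in ggml-qnn.cpp, and the static_assert is an added guard, not something the patch itself contains.

#include <cstdio>

// Toy stand-ins for ggml_op / ggml_qnn_func_t.
enum sample_op { SAMPLE_OP_NONE, SAMPLE_OP_ADD, SAMPLE_OP_MUL, SAMPLE_OP_COUNT };

typedef void (*sample_func_t)(int a, int b);

static void handle_add(int a, int b) { std::printf("add: %d\n", a + b); }

// One slot per enum value, in enum order; a null slot means "not supported".
static sample_func_t op_table[SAMPLE_OP_COUNT] = {
    nullptr,    // SAMPLE_OP_NONE
    handle_add, // SAMPLE_OP_ADD
    nullptr,    // SAMPLE_OP_MUL
};

// Guards the assumption the patch relies on implicitly: the positional
// initializers must cover every op, in the same order as the enum.
static_assert(sizeof(op_table) / sizeof(op_table[0]) == SAMPLE_OP_COUNT,
              "op table must have one entry per op");

static bool can_handle(sample_op op) {
    return op_table[op] != nullptr;   // same shape as the new can-handle check
}

static bool forward(sample_op op, int a, int b) {
    sample_func_t func = op_table[op];
    if (!func) {
        return false;                 // mirrors the QNN_LOG_WARN + early return
    }
    func(a, b);
    return true;
}

int main() {
    std::printf("ADD supported: %d\n", can_handle(SAMPLE_OP_ADD)); // prints 1
    std::printf("MUL supported: %d\n", can_handle(SAMPLE_OP_MUL)); // prints 0
    forward(SAMPLE_OP_ADD, 2, 3);                                  // prints "add: 5"
    return 0;
}

The upside of this layout is that supporting, checking, and dispatching an op all read from one place; the cost is that the table is ordered positionally, so it must be kept in sync with the op enum by hand (or by a compile-time check such as the one sketched here).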