fix op handle checker

hongruichen 2024-06-24 12:06:42 +08:00
parent ff0359d6f4
commit e1056da1c0


@@ -354,12 +354,100 @@ static int free_qnn_tensor(Qnn_Tensor_t & tensor) {
// implementation of QNN backend for GGML
//
// =================================================================================================
static void ggml_qnn_add(ggml_backend_qnn_context * ctx, const ggml_tensor * src0,
const ggml_tensor * src1, ggml_tensor * dst);
static void ggml_qnn_mul_mat(ggml_backend_qnn_context * ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
static ggml_qnn_func_t s_op_table[GGML_OP_COUNT] = {
nullptr, // GGML_OP_NONE
nullptr, // GGML_OP_DUP
ggml_qnn_add, // GGML_OP_ADD
nullptr, // GGML_OP_ADD1
nullptr, // GGML_OP_ACC
nullptr, // GGML_OP_SUB
nullptr, // GGML_OP_MUL
nullptr, // GGML_OP_DIV
nullptr, // GGML_OP_SQR
nullptr, // GGML_OP_SQRT
nullptr, // GGML_OP_LOG
nullptr, // GGML_OP_SUM
nullptr, // GGML_OP_SUM_ROWS
nullptr, // GGML_OP_MEAN
nullptr, // GGML_OP_ARGMAX
nullptr, // GGML_OP_REPEAT
nullptr, // GGML_OP_REPEAT_BACK
nullptr, // GGML_OP_CONCAT
nullptr, // GGML_OP_SILU_BACK
nullptr, // GGML_OP_NORM
nullptr, // GGML_OP_RMS_NORM
nullptr, // GGML_OP_RMS_NORM_BACK
nullptr, // GGML_OP_GROUP_NORM
ggml_qnn_mul_mat, // GGML_OP_MUL_MAT
nullptr, // GGML_OP_MUL_MAT_ID
nullptr, // GGML_OP_OUT_PROD
nullptr, // GGML_OP_SCALE
nullptr, // GGML_OP_SET
nullptr, // GGML_OP_CPY
nullptr, // GGML_OP_CONT
nullptr, // GGML_OP_RESHAPE
nullptr, // GGML_OP_VIEW
nullptr, // GGML_OP_PERMUTE
nullptr, // GGML_OP_TRANSPOSE
nullptr, // GGML_OP_GET_ROWS
nullptr, // GGML_OP_GET_ROWS_BACK
nullptr, // GGML_OP_DIAG
nullptr, // GGML_OP_DIAG_MASK_INF
nullptr, // GGML_OP_DIAG_MASK_ZERO
nullptr, // GGML_OP_SOFT_MAX
nullptr, // GGML_OP_SOFT_MAX_BACK
nullptr, // GGML_OP_ROPE
nullptr, // GGML_OP_ROPE_BACK
nullptr, // GGML_OP_CLAMP
nullptr, // GGML_OP_CONV_TRANSPOSE_1D
nullptr, // GGML_OP_IM2COL
nullptr, // GGML_OP_CONV_TRANSPOSE_2D
nullptr, // GGML_OP_POOL_1D
nullptr, // GGML_OP_POOL_2D
nullptr, // GGML_OP_UPSCALE
nullptr, // GGML_OP_PAD
nullptr, // GGML_OP_ARANGE
nullptr, // GGML_OP_TIMESTEP_EMBEDDING
nullptr, // GGML_OP_ARGSORT
nullptr, // GGML_OP_LEAKY_RELU
nullptr, // GGML_OP_FLASH_ATTN_EXT
nullptr, // GGML_OP_FLASH_ATTN_BACK
nullptr, // GGML_OP_SSM_CONV
nullptr, // GGML_OP_SSM_SCAN
nullptr, // GGML_OP_WIN_PART
nullptr, // GGML_OP_WIN_UNPART
nullptr, // GGML_OP_GET_REL_POS
nullptr, // GGML_OP_ADD_REL_POS
nullptr, // GGML_OP_UNARY
nullptr, // GGML_OP_MAP_UNARY
nullptr, // GGML_OP_MAP_BINARY
nullptr, // GGML_OP_MAP_CUSTOM1_F32
nullptr, // GGML_OP_MAP_CUSTOM2_F32
nullptr, // GGML_OP_MAP_CUSTOM3_F32
nullptr, // GGML_OP_MAP_CUSTOM1
nullptr, // GGML_OP_MAP_CUSTOM2
nullptr, // GGML_OP_MAP_CUSTOM3
nullptr, // GGML_OP_CROSS_ENTROPY_LOSS
nullptr, // GGML_OP_CROSS_ENTROPY_LOSS_BACK
};
static bool ggml_qnn_can_handle_op(ggml_backend_qnn_context * ctx,
const struct ggml_tensor * tensor,
bool b_dump_tensor_info) {
if (ggml_is_empty(tensor) || tensor->op == GGML_OP_RESHAPE ||
tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_VIEW ||
tensor->op == GGML_OP_PERMUTE || tensor->op == GGML_OP_NONE) {
if (ggml_is_empty(tensor) || !s_op_table[tensor->op]) {
return false;
}
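
Note on the pattern: s_op_table above keys the handlers positionally by the GGML_OP_* enum, and a nullptr slot doubles as the "unsupported op" signal that ggml_qnn_can_handle_op now tests. A minimal, self-contained sketch of the same dispatch-table pattern; the names here (my_op, my_handler_t, handle_add, s_handlers) are invented for illustration and are not ggml identifiers:

// Sketch only: illustrates the dispatch-table pattern used by s_op_table.
#include <cstdio>

enum my_op { MY_OP_NONE, MY_OP_ADD, MY_OP_MUL, MY_OP_COUNT };

typedef void (*my_handler_t)(int a, int b);

static void handle_add(int a, int b) { std::printf("add -> %d\n", a + b); }

// Positional initialization: entry order must match the enum declaration
// order exactly, which is why every unsupported op still needs an explicit
// nullptr slot.
static my_handler_t s_handlers[MY_OP_COUNT] = {
    nullptr,     // MY_OP_NONE
    handle_add,  // MY_OP_ADD
    nullptr,     // MY_OP_MUL (not implemented)
};

static bool dispatch(my_op op, int a, int b) {
    my_handler_t func = s_handlers[op];
    if (!func) {
        return false;  // unsupported: caller falls back to another backend
    }
    func(a, b);
    return true;
}

int main() {
    dispatch(MY_OP_ADD, 2, 3);  // prints "add -> 5"
    dispatch(MY_OP_MUL, 2, 3);  // returns false, no handler registered
    return 0;
}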
@@ -369,10 +457,10 @@ static bool ggml_qnn_can_handle_op(ggml_backend_qnn_context * ctx,
return false;
}
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const auto ne00 = src0->ne[0];
const auto ne01 = src0->ne[1];
const auto ne10 = src1->ne[0];
const auto ne11 = src1->ne[1];
// make qnn_get_ggml_tensor_rank and QNN SDK happy
if (ne00 <= 1 || ne01 <= 1 || ne10 <= 1 || ne11 <= 1) {
return false;
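
The comment "make qnn_get_ggml_tensor_rank and QNN SDK happy" is terse; the guard exists because the rank handed to QNN is presumably derived from the number of non-unit dimensions, so a dimension of 1 in src0/src1 would make the QNN tensor descriptor disagree with the GGML shape. A plausible sketch of such a rank helper, not taken from this commit, with an invented name:

// Plausible sketch (not the repository's qnn_get_ggml_tensor_rank): count
// only dimensions larger than 1. With e.g. ne = {4096, 1, 1, 1} the computed
// rank collapses to 1, hence the ne <= 1 rejection above.
#include <cstdint>

#define SKETCH_MAX_DIMS 4  // GGML_MAX_DIMS is 4 in ggml.h

static uint32_t sketch_tensor_rank(const int64_t ne[SKETCH_MAX_DIMS]) {
    uint32_t rank = 0;
    for (int i = 0; i < SKETCH_MAX_DIMS; ++i) {
        if (ne[i] > 1) {
            rank++;  // unit dimensions do not contribute to the rank
        }
    }
    return rank;
}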
@@ -951,132 +1039,13 @@ static void ggml_qnn_nop(ggml_backend_qnn_context * ctx, const ggml_tensor * src
bool ggml_qnn_compute_forward(ggml_backend_qnn_context * ctx,
struct ggml_compute_params * params,
struct ggml_tensor * tensor) {
ggml_qnn_func_t func = nullptr;
switch (tensor->op) {
case GGML_OP_ADD:
func = ggml_qnn_add;
break;
case GGML_OP_MUL_MAT:
func = ggml_qnn_mul_mat;
break;
case GGML_OP_REPEAT:
func = ggml_qnn_repeat;
break;
case GGML_OP_GET_ROWS:
func = ggml_qnn_get_rows;
break;
case GGML_OP_DUP:
func = ggml_qnn_dup;
break;
case GGML_OP_ACC:
func = ggml_qnn_acc;
break;
case GGML_OP_DIV:
func = ggml_qnn_div;
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_GELU:
func = ggml_qnn_gelu;
break;
case GGML_UNARY_OP_SILU:
func = ggml_qnn_silu;
break;
case GGML_UNARY_OP_GELU_QUICK:
func = ggml_qnn_gelu_quick;
break;
case GGML_UNARY_OP_TANH:
func = ggml_qnn_tanh;
break;
case GGML_UNARY_OP_RELU:
func = ggml_qnn_relu;
break;
case GGML_UNARY_OP_HARDSIGMOID:
func = ggml_qnn_hardsigmoid;
break;
case GGML_UNARY_OP_HARDSWISH:
func = ggml_qnn_hardswish;
break;
default:
return false;
}
break;
case GGML_OP_NORM:
func = ggml_qnn_norm;
break;
case GGML_OP_GROUP_NORM:
func = ggml_qnn_group_norm;
break;
case GGML_OP_CONCAT:
func = ggml_qnn_concat;
break;
case GGML_OP_UPSCALE:
func = ggml_qnn_upscale;
break;
case GGML_OP_PAD:
func = ggml_qnn_pad;
break;
case GGML_OP_LEAKY_RELU:
func = ggml_qnn_leaky_relu;
break;
case GGML_OP_RMS_NORM:
func = ggml_qnn_rms_norm;
break;
case GGML_OP_MUL_MAT_ID:
func = ggml_qnn_mul_mat_id;
break;
case GGML_OP_SCALE:
func = ggml_qnn_scale;
break;
case GGML_OP_SQR:
func = ggml_qnn_sqr;
break;
case GGML_OP_CLAMP:
func = ggml_qnn_clamp;
break;
case GGML_OP_CPY:
func = ggml_qnn_cpy;
break;
case GGML_OP_CONT:
func = ggml_qnn_dup;
break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
func = ggml_qnn_nop;
break;
case GGML_OP_DIAG_MASK_INF:
func = ggml_qnn_diag_mask_inf;
break;
case GGML_OP_SOFT_MAX:
func = ggml_qnn_soft_max;
break;
case GGML_OP_ROPE:
func = ggml_qnn_rope;
break;
case GGML_OP_IM2COL:
func = ggml_qnn_im2col;
break;
case GGML_OP_POOL_2D:
func = ggml_qnn_pool2d;
break;
case GGML_OP_SUM_ROWS:
func = ggml_qnn_sum_rows;
break;
case GGML_OP_ARGSORT:
func = ggml_qnn_argsort;
break;
default:
ggml_qnn_func_t func = s_op_table[tensor->op];
if (!func) {
QNN_LOG_WARN("unsupported op %d", tensor->op);
return false;
}
if (nullptr != func) {
func(ctx, tensor->src[0], tensor->src[1], tensor);
}
func(ctx, tensor->src[0], tensor->src[1], tensor);
return true;
}
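
For context on who calls this dispatcher: a minimal sketch of a backend graph-compute loop driving ggml_qnn_compute_forward once per node. The loop body is illustrative only (the file's real graph_compute callback is not part of this diff); ggml_cgraph does expose n_nodes and nodes, and params is passed as nullptr here just as the old offload_op call below did:

// Sketch only: assumes the file's existing includes and types; this is not
// the repository's graph_compute implementation.
static bool sketch_qnn_graph_compute(ggml_backend_qnn_context * ctx,
                                     ggml_cgraph * cgraph) {
    for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_tensor * node = cgraph->nodes[i];
        if (ggml_is_empty(node)) {
            continue;  // nothing to execute for this node
        }
        // nodes are expected to have passed ggml_qnn_can_handle_op already,
        // so a false return here means the checker and the table disagree
        if (!ggml_qnn_compute_forward(ctx, nullptr /* params */, node)) {
            return false;
        }
    }
    return true;
}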
@@ -1349,7 +1318,7 @@ GGML_CALL static bool ggml_backend_qnn_supports_op(ggml_backend_t backend,
GGML_CALL static bool ggml_backend_qnn_offload_op(ggml_backend_t backend, const ggml_tensor * tensor) {
ggml_backend_qnn_context * ctx = (ggml_backend_qnn_context *) backend->context;
return ggml_qnn_compute_forward(ctx, nullptr, (ggml_tensor *) tensor);
return ggml_qnn_can_handle_op(ctx, tensor, false);
}
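
The offload_op callback now answers from ggml_qnn_can_handle_op instead of actually running the op. The matching supports_op callback is referenced in the hunk header above but its body is not shown here; a hedged sketch, assuming it routes through the same checker so the two callbacks cannot report conflicting capabilities:

// Sketch only: not the body from this diff. If supports_op and offload_op
// both delegate to ggml_qnn_can_handle_op, they answer from the same
// s_op_table and stay consistent with each other.
GGML_CALL static bool sketch_backend_qnn_supports_op(ggml_backend_t backend,
                                                     const ggml_tensor * op) {
    ggml_backend_qnn_context * ctx = (ggml_backend_qnn_context *) backend->context;
    return ggml_qnn_can_handle_op(ctx, op, false);
}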
static ggml_backend_i ggml_backend_qnn_interface = {