feat: check if dims equal for add

looks qnn add can only applied to matrix with equal dimensions
This commit is contained in:
hongruichen 2024-07-27 13:31:57 +08:00
parent 5da73f8085
commit e0c9b34016

View file

@ -43,6 +43,27 @@ bool qnn_is_valid_params(ggml_backend_qnn_context *ctx, const ggml_tensor *src0,
return true;
}
bool is_tensor_dimensions_equal(const ggml_tensor *l, const ggml_tensor *r) {
const auto dim_l = ggml_n_dims(l);
if (dim_l != ggml_n_dims(r)) {
return false;
}
for (int i = 0; i < dim_l; i++) {
if (l->ne[i] != r->ne[i]) {
return false;
}
}
return true;
}
void print_ggml_tensor(const ggml_tensor *tensor) {
QNN_LOG_DEBUG("%15s: type = %i (%5s) ne = %5" PRIi64 " x %5" PRIi64 " x %5" PRIi64 ", nb = (%5zi, %5zi, %5zi)\n",
tensor->name, tensor->type, ggml_type_name(tensor->type), tensor->ne[0], tensor->ne[1], tensor->ne[2],
tensor->nb[0], tensor->nb[1], tensor->nb[2]);
}
} // namespace
#define CHECK_PARAMS(ctx, ...) \
@ -65,12 +86,6 @@ typedef const ggml_qnn_binary_op_t (&ggml_qnn_binary_op_array_t)[GGML_OP_COUNT];
constexpr const size_t kGgmlUnaryOpStart = GGML_OP_COUNT;
void print_ggml_tensor(const ggml_tensor *tensor) {
QNN_LOG_DEBUG("%15s: type = %i (%5s) ne = %5" PRIi64 " x %5" PRIi64 " x %5" PRIi64 ", nb = (%5zi, %5zi, %5zi)\n",
tensor->name, tensor->type, ggml_type_name(tensor->type), tensor->ne[0], tensor->ne[1], tensor->ne[2],
tensor->nb[0], tensor->nb[1], tensor->nb[2]);
}
template <size_t _Size>
qnn::ggml_tensor_array_t to_ggml_tensor_array(const std::array<ggml_tensor *, _Size> &array) {
return qnn::ggml_tensor_array_t(array.data(), array.data() + _Size);
@ -512,6 +527,11 @@ bool ggml_qnn_supports_op(const ggml_tensor *op) {
QNN_LOG_DEBUG("src0 or src1 is nullptr");
return false;
}
if (op->op == GGML_OP_ADD && !is_tensor_dimensions_equal(op->src[0], op->src[1])) {
QNN_LOG_DEBUG("src0 and src1 dimensions are not equal");
return false;
}
}
switch (op->type) {