fix: try fix error in 2nd run by appending dimension into graph key
This commit is contained in:
parent
ee305cc171
commit
47735cb589
1 changed files with 26 additions and 10 deletions
|
@ -78,6 +78,31 @@ bool execute_graph(qnn::ggml_qnn_graph *graph, const std::array<ggml_tensor *, _
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <size_t _InputSize, size_t _OutputSize>
|
||||||
|
std::string get_graph_key(const std::string &op_name, const std::array<ggml_tensor *, _InputSize> &inputs,
|
||||||
|
const std::array<ggml_tensor *, _OutputSize> &outputs) {
|
||||||
|
constexpr static const auto append_dimensions = [](std::string &key, const ggml_tensor *tensor) {
|
||||||
|
key += "_";
|
||||||
|
key += std::to_string(tensor->ne[0]);
|
||||||
|
key += "x";
|
||||||
|
key += std::to_string(tensor->ne[1]);
|
||||||
|
key += "x";
|
||||||
|
key += std::to_string(tensor->ne[2]);
|
||||||
|
key += "x";
|
||||||
|
key += std::to_string(tensor->ne[3]);
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string graph_key(op_name);
|
||||||
|
for (auto &input : inputs) {
|
||||||
|
append_dimensions(graph_key, input);
|
||||||
|
}
|
||||||
|
for (auto &output : outputs) {
|
||||||
|
append_dimensions(graph_key, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
return graph_key;
|
||||||
|
}
|
||||||
|
|
||||||
template <size_t _InputSize, size_t _OutputSize>
|
template <size_t _InputSize, size_t _OutputSize>
|
||||||
qnn::ggml_qnn_graph *get_qnn_graph_from_cache(ggml_backend_qnn_context *ctx, size_t op, const std::string &qnn_op,
|
qnn::ggml_qnn_graph *get_qnn_graph_from_cache(ggml_backend_qnn_context *ctx, size_t op, const std::string &qnn_op,
|
||||||
const std::array<ggml_tensor *, _InputSize> &inputs,
|
const std::array<ggml_tensor *, _InputSize> &inputs,
|
||||||
|
@ -87,16 +112,7 @@ qnn::ggml_qnn_graph *get_qnn_graph_from_cache(ggml_backend_qnn_context *ctx, siz
|
||||||
auto &graph_cache = ctx->qnn_graph_cache;
|
auto &graph_cache = ctx->qnn_graph_cache;
|
||||||
const auto *op_name = op < qnn::kGgmlUnaryOpStart ? ggml_op_name(ggml_op(op))
|
const auto *op_name = op < qnn::kGgmlUnaryOpStart ? ggml_op_name(ggml_op(op))
|
||||||
: ggml_unary_op_name(ggml_unary_op(op - qnn::kGgmlUnaryOpStart));
|
: ggml_unary_op_name(ggml_unary_op(op - qnn::kGgmlUnaryOpStart));
|
||||||
std::string graph_key(op_name);
|
auto graph_key = get_graph_key<_InputSize, _OutputSize>(op_name, inputs, outputs);
|
||||||
for (auto &input : inputs) {
|
|
||||||
graph_key += "_";
|
|
||||||
graph_key += input->name;
|
|
||||||
}
|
|
||||||
for (auto &output : outputs) {
|
|
||||||
graph_key += "_";
|
|
||||||
graph_key += output->name;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto it = graph_cache.find(graph_key);
|
auto it = graph_cache.find(graph_key);
|
||||||
qnn::ggml_qnn_graph *graph_ptr = nullptr;
|
qnn::ggml_qnn_graph *graph_ptr = nullptr;
|
||||||
if (it != graph_cache.end()) {
|
if (it != graph_cache.end()) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue