fix: try fix graph cache with append the tensors name
This commit is contained in:
parent
51f95d6980
commit
5f3b1ae3b0
2 changed files with 18 additions and 8 deletions
|
@ -96,9 +96,14 @@ public:
|
|||
bool is_valid() const { return _buffer != nullptr; }
|
||||
|
||||
bool init_tensor(ggml_tensor *tensor) {
|
||||
if (qnn::ggml_qnn_tensor::from_ggml_tensor(tensor)) {
|
||||
QNN_LOG_INFO("tensor %s already initialized", tensor->name);
|
||||
return true;
|
||||
}
|
||||
|
||||
auto qnn_tensor = std::make_unique<qnn::ggml_qnn_tensor>(tensor, _device, _instance);
|
||||
if (!qnn_tensor->is_valid()) {
|
||||
QNN_LOG_WARN("Create ggml_qnn_tensor failed");
|
||||
QNN_LOG_WARN("create ggml_qnn_tensor failed");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -167,19 +167,23 @@ qnn::ggml_qnn_graph<_InputSize, _OutputSize> *get_qnn_graph_from_cache(
|
|||
auto &graph_cache = get_qnn_graph_cache(ctx, inputs, outputs);
|
||||
const auto *op_name = op < qnn::kGgmlUnaryOpStart ? ggml_op_name(ggml_op(op))
|
||||
: ggml_unary_op_name(ggml_unary_op(op - qnn::kGgmlUnaryOpStart));
|
||||
const std::string graph_key(op_name);
|
||||
std::string graph_key(op_name);
|
||||
for (auto &input : inputs) {
|
||||
graph_key += "_";
|
||||
graph_key += input->name;
|
||||
}
|
||||
for (auto &output : outputs) {
|
||||
graph_key += "_";
|
||||
graph_key += output->name;
|
||||
}
|
||||
|
||||
auto it = graph_cache.find(graph_key);
|
||||
graph_t *graph_ptr = nullptr;
|
||||
if (it != graph_cache.end()) {
|
||||
graph_ptr = it->second.get();
|
||||
} else {
|
||||
std::string graph_name = graph_key + "_" + std::to_string(ctx->threads);
|
||||
for (auto &input : inputs) {
|
||||
graph_name += "_";
|
||||
graph_name += input->name;
|
||||
}
|
||||
auto graph =
|
||||
std::make_unique<graph_t>(graph_name, (QNNBackend)(ctx->device), ctx->instance->get_qnn_context_handle(),
|
||||
std::make_unique<graph_t>(graph_key, (QNNBackend)(ctx->device), ctx->instance->get_qnn_context_handle(),
|
||||
ctx->qnn_interface, ctx->socinfo.vtcm_size_in_mb);
|
||||
|
||||
if (!graph->is_valid()) {
|
||||
|
@ -187,6 +191,7 @@ qnn::ggml_qnn_graph<_InputSize, _OutputSize> *get_qnn_graph_from_cache(
|
|||
}
|
||||
|
||||
if (!qnn_bind_tensors_to_graph<_InputSize, _OutputSize>(graph.get(), qnn_op.c_str(), inputs, outputs)) {
|
||||
QNN_LOG_ERROR("qnn_bind_tensors_to_graph failed\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue