fix: try fix graph cache with append the tensors name

This commit is contained in:
hongruichen 2024-07-20 16:21:09 +08:00
parent 51f95d6980
commit 5f3b1ae3b0
2 changed files with 18 additions and 8 deletions

View file

@ -96,9 +96,14 @@ public:
bool is_valid() const { return _buffer != nullptr; }
bool init_tensor(ggml_tensor *tensor) {
if (qnn::ggml_qnn_tensor::from_ggml_tensor(tensor)) {
QNN_LOG_INFO("tensor %s already initialized", tensor->name);
return true;
}
auto qnn_tensor = std::make_unique<qnn::ggml_qnn_tensor>(tensor, _device, _instance);
if (!qnn_tensor->is_valid()) {
QNN_LOG_WARN("Create ggml_qnn_tensor failed");
QNN_LOG_WARN("create ggml_qnn_tensor failed");
return false;
}

View file

@ -167,19 +167,23 @@ qnn::ggml_qnn_graph<_InputSize, _OutputSize> *get_qnn_graph_from_cache(
auto &graph_cache = get_qnn_graph_cache(ctx, inputs, outputs);
const auto *op_name = op < qnn::kGgmlUnaryOpStart ? ggml_op_name(ggml_op(op))
: ggml_unary_op_name(ggml_unary_op(op - qnn::kGgmlUnaryOpStart));
const std::string graph_key(op_name);
std::string graph_key(op_name);
for (auto &input : inputs) {
graph_key += "_";
graph_key += input->name;
}
for (auto &output : outputs) {
graph_key += "_";
graph_key += output->name;
}
auto it = graph_cache.find(graph_key);
graph_t *graph_ptr = nullptr;
if (it != graph_cache.end()) {
graph_ptr = it->second.get();
} else {
std::string graph_name = graph_key + "_" + std::to_string(ctx->threads);
for (auto &input : inputs) {
graph_name += "_";
graph_name += input->name;
}
auto graph =
std::make_unique<graph_t>(graph_name, (QNNBackend)(ctx->device), ctx->instance->get_qnn_context_handle(),
std::make_unique<graph_t>(graph_key, (QNNBackend)(ctx->device), ctx->instance->get_qnn_context_handle(),
ctx->qnn_interface, ctx->socinfo.vtcm_size_in_mb);
if (!graph->is_valid()) {
@ -187,6 +191,7 @@ qnn::ggml_qnn_graph<_InputSize, _OutputSize> *get_qnn_graph_from_cache(
}
if (!qnn_bind_tensors_to_graph<_InputSize, _OutputSize>(graph.get(), qnn_op.c_str(), inputs, outputs)) {
QNN_LOG_ERROR("qnn_bind_tensors_to_graph failed\n");
return nullptr;
}