add function to get graph from cache

hongruichen 2024-07-10 23:00:31 +08:00
parent 80051cfc4d
commit b6f29273f0
2 changed files with 47 additions and 61 deletions


@@ -108,6 +108,41 @@ bool execute_graph(qnn::ggml_qnn_graph<_InputSize, _OutputSize> *graph,
     return true;
 }
 
+template <size_t _InputSize, size_t _OutputSize>
+qnn::ggml_qnn_graph_binary *get_qnn_graph_from_cache(ggml_backend_qnn_context *ctx, ggml_op op,
+                                                     const std::string &qnn_op,
+                                                     const std::array<const ggml_tensor *, _InputSize> &inputs,
+                                                     const std::array<ggml_tensor *, _OutputSize> &outputs) {
+    const std::string graph_key(ggml_op_name(op));
+    auto it = ctx->qnn_binary_graph_cache.find(graph_key);
+    qnn::ggml_qnn_graph_binary *graph_ptr = nullptr;
+
+    if (it != ctx->qnn_binary_graph_cache.end()) {
+        graph_ptr = it->second.get();
+    } else {
+        std::string graph_name = graph_key + "_" + std::to_string(ctx->threads);
+        for (auto &input : inputs) {
+            graph_name += "_";
+            graph_name += input->name;
+        }
+        auto graph = std::make_unique<qnn::ggml_qnn_graph_binary>(graph_name, (QNNBackend)(ctx->device),
+                                                                  ctx->instance->get_qnn_context_handle(),
+                                                                  ctx->raw_interface, ctx->socinfo.vtcm_size_in_mb);
+        if (!graph->is_valid()) {
+            return nullptr;
+        }
+
+        if (!qnn_bind_tensors_to_graph<_InputSize, _OutputSize>(graph.get(), qnn_op.c_str(), inputs, outputs)) {
+            return nullptr;
+        }
+
+        graph_ptr = graph.get();
+        ctx->qnn_binary_graph_cache[graph_key] = std::move(graph);
+    }
+
+    return graph_ptr;
+}
+
 } // namespace
 
 #ifndef NDEBUG
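
For reference, the helper above assumes a per-context graph cache on ggml_backend_qnn_context. Its exact declaration is not part of this diff; the find()/unique_ptr usage implies something like the following sketch (illustrative only):

    #include <memory>
    #include <string>
    #include <unordered_map>

    // Assumed member of ggml_backend_qnn_context (not shown in this commit).
    // The cache is keyed by ggml_op_name(op) alone, so all calls for one op share
    // a single cached graph; the tensor names only feed the graph's debug name.
    std::unordered_map<std::string, std::unique_ptr<qnn::ggml_qnn_graph_binary>> qnn_binary_graph_cache;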
@@ -126,44 +161,21 @@ static void ggml_qnn_add(ggml_backend_qnn_context *ctx, const ggml_tensor *src0,
                          ggml_tensor *dst) {
     CHECK_PARAMS(ctx, src0, src1, dst);
 
-    std::string graph_name = "ggml_op_qnn_add";
-    qnn::qnn_perf perf(graph_name);
+    qnn::qnn_perf perf("ggml_op_qnn_add");
     perf.start();
 
     bool succeed = false;
-    std::string graph_key(ggml_op_name(GGML_OP_ADD));
-    auto it = ctx->qnn_binary_graph_cache.find(graph_key);
-    qnn::ggml_qnn_graph_binary *graph_ptr = nullptr;
-    if (it != ctx->qnn_binary_graph_cache.end()) {
-        graph_ptr = it->second.get();
-    } else {
-        graph_name = graph_name + "_" + std::to_string(ctx->threads) + "_" + src0->name + "_" + src1->name;
-        auto graph = std::make_unique<qnn::ggml_qnn_graph_binary>(graph_name, (QNNBackend)(ctx->device),
-                                                                  ctx->instance->get_qnn_context_handle(),
-                                                                  ctx->raw_interface, ctx->socinfo.vtcm_size_in_mb);
-        if (!graph->is_valid()) {
-            goto failure;
-        }
-        if (!qnn_bind_tensors_to_graph<2, 1>(graph.get(), QNN_OP_ELEMENT_WISE_ADD, { src0, src1 }, { dst })) {
-            goto failure;
-        }
-        graph_ptr = graph.get();
-        ctx->qnn_binary_graph_cache[graph_key] = std::move(graph);
+    qnn::ggml_qnn_graph_binary *graph_ptr =
+        get_qnn_graph_from_cache<2, 1>(ctx, GGML_OP_ADD, QNN_OP_ELEMENT_WISE_ADD, { src0, src1 }, { dst });
+    if (graph_ptr) {
+        succeed = execute_graph<2, 1>(graph_ptr, { src0, src1 }, { dst });
     }
-    succeed = execute_graph<2, 1>(graph_ptr, { src0, src1 }, { dst });
 
-failure:
     if (!succeed) {
         print_ggml_tensor(src0);
         print_ggml_tensor(src1);
         print_ggml_tensor(dst);
     }
-
-    perf.info();
 }
 
 /*
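
With the helper in place, routing another binary op through the cache reduces to picking the ggml op and its QNN counterpart. A hypothetical sketch (GGML_OP_SUB and QNN_OP_ELEMENT_WISE_SUBTRACT are illustrative only; this commit wires up just ADD and MUL_MAT):

    // Hypothetical: elementwise subtract through the same cache helper.
    bool succeed = false;
    qnn::ggml_qnn_graph_binary *graph_ptr =
        get_qnn_graph_from_cache<2, 1>(ctx, GGML_OP_SUB, QNN_OP_ELEMENT_WISE_SUBTRACT, { src0, src1 }, { dst });
    if (graph_ptr) {
        succeed = execute_graph<2, 1>(graph_ptr, { src0, src1 }, { dst });
    }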
@@ -181,49 +193,21 @@ static void ggml_qnn_mul_mat(ggml_backend_qnn_context *ctx, const ggml_tensor *s
                              ggml_tensor *dst) {
     CHECK_PARAMS(ctx, src0, src1, dst);
 
-    std::string graph_name = "ggml_op_qnn_mul_mat";
-    qnn::qnn_perf perf(graph_name);
+    qnn::qnn_perf perf("ggml_op_qnn_mul_mat");
     perf.start();
 
     // TODO: for scenarios of quantized data in src0
     //       pass-1: dequantize src0 to FP32
     //       pass-2: dq-src0 * src1
     //       the performance gains are worth it although pass-1 costs some time
 
     bool succeed = false;
-    std::string graph_key(ggml_op_name(GGML_OP_MUL_MAT));
-    auto it = ctx->qnn_binary_graph_cache.find(graph_key);
-    qnn::ggml_qnn_graph_binary *graph_ptr = nullptr;
-    if (it != ctx->qnn_binary_graph_cache.end()) {
-        graph_ptr = it->second.get();
-    } else {
-        graph_name = graph_name + "_" + std::to_string(ctx->threads) + "_" + src0->name + "_" + src1->name;
-        auto graph = std::make_unique<qnn::ggml_qnn_graph_binary>(graph_name, (QNNBackend)(ctx->device),
-                                                                  ctx->instance->get_qnn_context_handle(),
-                                                                  ctx->raw_interface, ctx->socinfo.vtcm_size_in_mb);
-        if (!graph->is_valid()) {
-            goto failure;
-        }
-        if (!qnn_bind_tensors_to_graph<2, 1>(graph.get(), QNN_OP_MAT_MUL, { src0, src1 }, { dst })) {
-            goto failure;
-        }
-        graph_ptr = graph.get();
-        ctx->qnn_binary_graph_cache[graph_key] = std::move(graph);
+    qnn::ggml_qnn_graph_binary *graph_ptr =
+        get_qnn_graph_from_cache<2, 1>(ctx, GGML_OP_MUL_MAT, QNN_OP_MAT_MUL, { src0, src1 }, { dst });
+    if (graph_ptr) {
+        succeed = execute_graph<2, 1>(graph_ptr, { src0, src1 }, { dst });
     }
-    succeed = execute_graph<2, 1>(graph_ptr, { src0, src1 }, { dst });
 
-failure:
     if (!succeed) {
         print_ggml_tensor(src0);
         print_ggml_tensor(src1);
        print_ggml_tensor(dst);
     }
-
-    perf.info();
 }
 
 static void ggml_qnn_repeat(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1,
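
The TODO in ggml_qnn_mul_mat above outlines a two-pass plan for quantized src0: dequantize to FP32 first, then multiply. A rough sketch of pass-1 using ggml's per-type converter (hypothetical; not part of this commit, and buffer handling is simplified):

    // Pass-1: dequantize a quantized src0 row by row into a temporary FP32 buffer.
    const auto type_traits = ggml_internal_get_type_traits(src0->type);
    std::vector<float> src0_fp32(ggml_nelements(src0));
    for (int64_t row = 0; row < ggml_nrows(src0); ++row) {
        const char *src_row = (const char *) src0->data + row * src0->nb[1];
        type_traits.to_float(src_row, src0_fp32.data() + row * src0->ne[0], src0->ne[0]);
    }
    // Pass-2 would bind src0_fp32 as an F32 tensor into the QNN MAT_MUL graph above.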


@@ -189,6 +189,7 @@ void device_tensor_free(Qnn_Tensor_t &tensor);
 class qnn_perf {
 public:
     qnn_perf(const std::string &perf_name) : _perf_name(std::move(perf_name)) {};
+    ~qnn_perf() { info(); }
     qnn_perf() = delete;
     qnn_perf(const qnn_perf &) = delete;
     qnn_perf &operator=(const qnn_perf &) = delete;
@@ -211,6 +212,7 @@ private:
 class qnn_perf {
 public:
     qnn_perf(const std::string &perf_name) {}
+    ~qnn_perf() { info(); }
     qnn_perf() = delete;
     qnn_perf(const qnn_perf &) = delete;
     qnn_perf &operator=(const qnn_perf &) = delete;
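
Both the real class and the NDEBUG stub gain this destructor, so perf reporting becomes RAII: info() runs when the scope exits, including on early returns. A minimal usage sketch (the op name is illustrative):

    static void ggml_qnn_some_op(ggml_backend_qnn_context *ctx) {
        qnn::qnn_perf perf("ggml_op_qnn_some_op");
        perf.start();
        // ... build or fetch the graph, then execute it ...
    } // ~qnn_perf() calls info() here, which is why the explicit perf.info()
      // calls at the end of each op were removed in the hunks above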