This commit is contained in:
hongruichen 2024-06-17 18:44:19 +08:00
parent 5fe7b87ba1
commit 9456bba121

View file

@ -2077,8 +2077,8 @@ private:
void operator=(ggml_qnn_tensor_readwrite&&) = delete;
};
using ggml_qnn_tensor_reader = ggml_qnn_tensor_readwrite<QNN_TENSOR_TYPE_APP_READ>;
using ggml_qnn_tensor_writer = ggml_qnn_tensor_readwrite<QNN_TENSOR_TYPE_APP_WRITE>;
using ggml_qnn_tensor_output = ggml_qnn_tensor_readwrite<QNN_TENSOR_TYPE_APP_READ>;
using ggml_qnn_tensor_input = ggml_qnn_tensor_readwrite<QNN_TENSOR_TYPE_APP_WRITE>;
//TODO: this function can be removed later because there are duplicated codes with ggml_qnn_mul_mat
// keep it for illustrate how to implement a specified GGMPL OP using QNN API + QNN RPC
@ -2164,22 +2164,22 @@ static void ggml_qnn_add(ggml_backend_qnn_context * ctx, const ggml_tensor * src
QNN_LOG_INFO("create qnn graph handle with graph name %s ok\n", graph_name.c_str());
}
ggml_qnn_tensor_writer tensor_writer0(src0, graph_handle, ctx);
if (!tensor_writer0.is_valid()) {
ggml_qnn_tensor_input tensor_input0(src0, graph_handle, ctx);
if (!tensor_input0.is_valid()) {
goto failure;
}
ggml_qnn_tensor_writer tensor_writer1(src1, graph_handle, ctx);
if (!tensor_writer1.is_valid()) {
ggml_qnn_tensor_input tensor_input1(src1, graph_handle, ctx);
if (!tensor_input1.is_valid()) {
QNN_LOG_INFO("error = %d\n", error);
goto failure;
}
ggml_qnn_tensor_reader tensor_reader(dst, graph_handle, ctx);
if (!tensor_reader.is_valid()) {
ggml_qnn_tensor_output tensor_output(dst, graph_handle, ctx);
if (!tensor_output.is_valid()) {
goto failure;
}
Qnn_Tensor_t tensor_inputs[] = {*tensor_writer0.get_qnn_tensor(), *tensor_writer1.get_qnn_tensor()};
Qnn_Tensor_t tensor_outputs[] = {*tensor_reader.get_qnn_tensor()};
Qnn_Tensor_t tensor_inputs[] = {*tensor_input0.get_qnn_tensor(), *tensor_input1.get_qnn_tensor()};
Qnn_Tensor_t tensor_outputs[] = {*tensor_output.get_qnn_tensor()};
Qnn_OpConfig_t op_config = {
(Qnn_OpConfigVersion_t) 1,
.v1 = {"ggml_op_add",
@ -2215,18 +2215,18 @@ static void ggml_qnn_add(ggml_backend_qnn_context * ctx, const ggml_tensor * src
}
auto graph_item = std::make_tuple(graph_handle,
tensor_writer0.get_qnn_tensor(),
tensor_writer1.get_qnn_tensor(),
tensor_reader.get_qnn_tensor());
tensor_input0.get_qnn_tensor(),
tensor_input1.get_qnn_tensor(),
tensor_output.get_qnn_tensor());
instance->_qnn_graph_map[map_entry] = graph_item;
} else {
auto & graph_item = instance->_qnn_graph_map[map_entry];
ggml_qnn_tensor_writer tensor_writer0(src0, std::get<1>(graph_item), ctx);
ggml_qnn_tensor_writer tensor_writer1(src1, std::get<2>(graph_item), ctx);
ggml_qnn_tensor_reader tensor_reader(dst, std::get<3>(graph_item), ctx);
ggml_qnn_tensor_input tensor_input0(src0, std::get<1>(graph_item), ctx);
ggml_qnn_tensor_input tensor_input1(src1, std::get<2>(graph_item), ctx);
ggml_qnn_tensor_output tensor_output(dst, std::get<3>(graph_item), ctx);
Qnn_Tensor_t tensor_inputs[] = {*tensor_writer0.get_qnn_tensor(), *tensor_writer1.get_qnn_tensor()};
Qnn_Tensor_t tensor_outputs[] = {*tensor_reader.get_qnn_tensor()};
Qnn_Tensor_t tensor_inputs[] = {*tensor_input0.get_qnn_tensor(), *tensor_input1.get_qnn_tensor()};
Qnn_Tensor_t tensor_outputs[] = {*tensor_output.get_qnn_tensor()};
error = qnn_raw_interface.graphExecute(graph_handle,
tensor_inputs,2,
tensor_outputs,1,
@ -2360,21 +2360,21 @@ static void ggml_qnn_mul_mat(ggml_backend_qnn_context * ctx,
goto failure;
}
ggml_qnn_tensor_writer tensor_writer0(src0, graph_handle, ctx);
if (!tensor_writer0.is_valid()) {
ggml_qnn_tensor_input tensor_input0(src0, graph_handle, ctx);
if (!tensor_input0.is_valid()) {
goto failure;
}
ggml_qnn_tensor_writer tensor_writer1(src1, graph_handle, ctx);
if (!tensor_writer1.is_valid()) {
ggml_qnn_tensor_input tensor_input1(src1, graph_handle, ctx);
if (!tensor_input1.is_valid()) {
goto failure;
}
ggml_qnn_tensor_reader tensor_reader(dst, graph_handle, ctx);
if (!tensor_reader.is_valid()) {
ggml_qnn_tensor_output tensor_output(dst, graph_handle, ctx);
if (!tensor_output.is_valid()) {
goto failure;
}
Qnn_Tensor_t tensor_inputs[] = {*tensor_writer0.get_qnn_tensor(), *tensor_writer1.get_qnn_tensor()};
Qnn_Tensor_t tensor_outputs[] = {*tensor_reader.get_qnn_tensor()};
Qnn_Tensor_t tensor_inputs[] = {*tensor_input0.get_qnn_tensor(), *tensor_input1.get_qnn_tensor()};
Qnn_Tensor_t tensor_outputs[] = {*tensor_output.get_qnn_tensor()};
Qnn_OpConfig_t op_config = {
(Qnn_OpConfigVersion_t) 1,
.v1 = {"ggml_op_mul_mat",
@ -2410,18 +2410,18 @@ static void ggml_qnn_mul_mat(ggml_backend_qnn_context * ctx,
}
auto graph_item = std::make_tuple(graph_handle,
tensor_writer0.get_qnn_tensor(),
tensor_writer1.get_qnn_tensor(),
tensor_reader.get_qnn_tensor());
tensor_input0.get_qnn_tensor(),
tensor_input1.get_qnn_tensor(),
tensor_output.get_qnn_tensor());
instance->_qnn_graph_map[map_entry] = graph_item;
} else {
auto & graph_item= instance->_qnn_graph_map[map_entry];
ggml_qnn_tensor_writer tensor_writer0(src0, std::get<1>(graph_item), ctx);
ggml_qnn_tensor_writer tensor_writer1(src1, std::get<2>(graph_item), ctx);
ggml_qnn_tensor_reader tensor_reader(dst, std::get<3>(graph_item), ctx);
ggml_qnn_tensor_input tensor_input0(src0, std::get<1>(graph_item), ctx);
ggml_qnn_tensor_input tensor_input1(src1, std::get<2>(graph_item), ctx);
ggml_qnn_tensor_output tensor_output(dst, std::get<3>(graph_item), ctx);
Qnn_Tensor_t tensor_inputs[] = {*tensor_writer0.get_qnn_tensor(), *tensor_writer1.get_qnn_tensor()};
Qnn_Tensor_t tensor_outputs[] = {*tensor_reader.get_qnn_tensor()};
Qnn_Tensor_t tensor_inputs[] = {*tensor_input0.get_qnn_tensor(), *tensor_input1.get_qnn_tensor()};
Qnn_Tensor_t tensor_outputs[] = {*tensor_output.get_qnn_tensor()};
error = qnn_raw_interface.graphExecute(graph_handle,
tensor_inputs, 2,
tensor_outputs, 1,