ggml-qnn: remove static global vars to support multi-instance simultaneously

parent f4c53037ab
commit 2fab33d825
2 changed files with 113 additions and 140 deletions

ggml-qnn.cpp (230 changes)
@@ -76,7 +76,7 @@ static void ggml_qnn_log_internal(ggml_log_level level, const char * file, const
 #define GGML_QNN_LOGBUF_LEN 4096

-#define GGML_QNN_DEBUG 0 //for troubleshooting QNN backend
+#define GGML_QNN_DEBUG 1 //for troubleshooting QNN backend

 #define QNN_LOG_ERROR(...) ggml_qnn_log_internal(GGML_LOG_LEVEL_DEBUG, __FILE__, __FUNCTION__, __LINE__, __VA_ARGS__)
 #define QNN_LOG_WARN(...) ggml_qnn_log_internal(GGML_LOG_LEVEL_DEBUG , __FILE__, __FUNCTION__, __LINE__, __VA_ARGS__)
@@ -89,7 +89,7 @@ static void ggml_qnn_log_internal(ggml_log_level level, const char * file, const
 #endif

 #define QNN_VER_PTR(x) (&((x).v1))

 #define GGML_QNN_NAME "qnn"

 #define VALIDATE(value, status) \
     do { \
@@ -135,8 +135,6 @@ using _pfn_QnnInterface_getProviders = decltype(QnnInterface_
 using _pfn_QnnSystemInterface_getProviders = decltype(QnnSystemInterface_getProviders);


-typedef void (* ggml_qnn_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
-typedef void (* ggml_qnn_func_common_t)(const ggml_op ggml_op, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);

 enum class ggml_qnn_profile_level {
     profile_off = 0,
@@ -144,7 +142,6 @@ enum class ggml_qnn_profile_level {
     profile_detail = 2
 };

-
 struct ggml_backend_qnn_context {
     int device;
     int threads;
@@ -156,15 +153,16 @@ struct ggml_backend_qnn_context {
     QNN_SYSTEM_INTERFACE_VER_TYPE raw_system_interface;
 } ;

+typedef void (* ggml_qnn_func_t)(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+typedef void (* ggml_qnn_func_common_t)(ggml_backend_qnn_context * ctx, const ggml_op ggml_op, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
 // =================================================================================================
 //
 //  static global variables
 //
 // =================================================================================================
-static ggml_backend_t g_qnn_backend = nullptr;
-
 static int g_current_device = QNN_BACKEND_GGML;
+//static ggml_backend_t g_qnn_backend = nullptr;

 //according to the QNN SDK Reference Guide,
 //CPU - Choose a non-quantized model. Quantized models are currently incompatible with the CPU backend
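This hunk is the heart of the commit: the process-wide `static ggml_backend_t g_qnn_backend` is retired (kept only as a comment) and the op function-pointer typedefs gain an explicit `ggml_backend_qnn_context *` first argument, so each call reaches state through its own instance rather than through a shared global. A minimal sketch of that pattern, with hypothetical names standing in for the real ggml/QNN types:

```cpp
// Sketch only: per-instance context passed explicitly instead of a static global.
// backend_context / op_func_t / op_add are illustrative names, not the real API.
#include <string>

struct backend_context {      // one per backend instance (CPU / GPU / NPU slot)
    int         device;
    std::string name;
};

typedef void (*op_func_t)(backend_context * ctx, const float * src, float * dst);

static void op_add(backend_context * ctx, const float * src, float * dst) {
    dst[0] = src[0];          // touches only state reachable from ctx
    (void) ctx;
}

int main() {
    backend_context cpu {0, "qnn-cpu"}, npu {2, "qnn-npu"};
    float in = 1.0f, out = 0.0f;
    op_func_t f = op_add;
    f(&cpu, &in, &out);       // two instances can be driven side by side,
    f(&npu, &in, &out);       // because nothing goes through a shared global
    return 0;
}
```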
@@ -184,7 +182,6 @@ static struct ggml_backend_qnn_context g_qnn_mgr[GGML_QNN_MAX_DEVICES] = {
     [QNN_BACKEND_NPU] = {.device = 2, .threads = 1, .name = "qnn-npu", .lib = "libQnnHtp.so", .instance = nullptr, .backend = nullptr, .raw_interface = {}, .raw_system_interface = {}},
 };

-
 // =================================================================================================
 //
 //  QNN helper functions and other internal helper functions
@@ -1010,7 +1007,7 @@ void qnn_instance::free_rpcmem(void * buf) {
 }


-int32_t qnn_instance::rpcmem_to_fd(void *buf) {
+int32_t qnn_instance::rpcmem_to_fd(void * buf) {
     int32_t mem_fd = -1;
     if (!is_rpcmem_initialized()) {
         QNN_LOG_WARN("rpc memory not initialized\n");
@@ -1168,33 +1165,6 @@ int qnn_instance::load_backend(std::string & lib_path, const QnnSaver_Config_t *
     _loaded_lib_handle[backend_id] = lib_handle;
     _backend_id = backend_id;

-#if 0 //comment it for purpose of reduce size of APK
-    QnnSaver_Config_t outputdir_cfg;
-    outputdir_cfg.option = QNN_SAVER_CONFIG_OPTION_OUTPUT_DIRECTORY;
-    outputdir_cfg.outputDirectory = "/data/local/tmp/";
-
-    QnnSaver_Config_t backendid_cfg;
-    backendid_cfg.option = QNN_SAVER_CONFIG_OPTION_BACKEND_ID;
-    backendid_cfg.backendId = _backend_id;
-    const QnnSaver_Config_t *saverCfg[] = {&outputdir_cfg, &backendid_cfg, nullptr};
-    if (0 == QnnSaver_initialize(saverCfg)) {
-        QNN_LOG_INFO("QnnSaver_initialize successfully");
-    } else {
-        QNN_LOG_WARN("QnnSaver_initialize failure");
-    }
-#endif
-    auto saver_initialize = load_qnn_functionpointers<_pfn_QnnSaver_initialize *>(
-                                _loaded_lib_handle[backend_id], "QnnSaver_initialize");
-    if (nullptr != saver_initialize) {
-        error = saver_initialize(saver_config);
-        if (error != QNN_SUCCESS) {
-            QNN_LOG_WARN("failed to saver_initialize,error %d", QNN_GET_ERROR_CODE(error));
-            return 7;
-        }
-    } else {
-        QNN_LOG_WARN("saver_initialize is null\n");
-    }
-
     return 0;
 }
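The removed block built QnnSaver configs behind `#if 0` and also resolved `QnnSaver_initialize` through `load_qnn_functionpointers`; with this commit the saver is no longer initialized in `load_backend` at all. For reference, a reduced sketch of the optional-symbol pattern the deleted code relied on, using plain dlsym(); the function-pointer type, symbol handling, and return codes are simplified placeholders:

```cpp
// Sketch: call an entry point only if the loaded library actually exports it.
// saver_init_fn and the simplified config type are illustrative placeholders.
#include <dlfcn.h>
#include <cstdio>

typedef int (*saver_init_fn)(const void ** config);

static int try_saver_initialize(void * lib_handle, const void ** config) {
    auto fn = reinterpret_cast<saver_init_fn>(dlsym(lib_handle, "QnnSaver_initialize"));
    if (fn == nullptr) {
        std::printf("saver_initialize is null\n");  // optional feature: not fatal
        return 0;
    }
    return fn(config) == 0 ? 0 : 7;                 // 7 mirrors the old failure return
}
```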
@@ -1345,14 +1315,15 @@ static void ggml_qnn_logcallback(const char * fmt,
     }

     double ms = (double) timestamp / 1000000.0;

-    if (0) {
+#if GGML_QNN_DEBUG
+    {
         std::lock_guard<std::mutex> lock(log_mutex);

         memset(s_ggml_qnn_logbuf, 0, GGML_QNN_LOGBUF_LEN);
         vsnprintf(reinterpret_cast<char *const>(s_ggml_qnn_logbuf), GGML_QNN_LOGBUF_LEN, fmt, argp);
         QNN_LOG_DEBUG("%8.1fms [%-7s] %s\n", ms, log_level_desc, s_ggml_qnn_logbuf);
     }
+#endif
 }
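The only change in the callback is swapping the dead `if (0)` guard for `#if GGML_QNN_DEBUG`, so the mutex-protected formatting into the shared log buffer is compiled in exactly when the debug flag defined above is set. A self-contained sketch of that callback shape, with placeholder names and a simplified level parameter:

```cpp
// Sketch: a QNN-style log callback that formats into a shared static buffer
// under a mutex, compiled in only when the debug flag is enabled.
#include <cstdarg>
#include <cstdio>
#include <cstring>
#include <mutex>

#define MY_QNN_DEBUG      1
#define MY_QNN_LOGBUF_LEN 4096

static void my_log_callback(const char * fmt, int level,
                            unsigned long long timestamp, va_list argp) {
    static std::mutex log_mutex;
    static char logbuf[MY_QNN_LOGBUF_LEN];

    double ms = (double) timestamp / 1000000.0;
#if MY_QNN_DEBUG
    {
        std::lock_guard<std::mutex> lock(log_mutex);   // one writer at a time
        std::memset(logbuf, 0, MY_QNN_LOGBUF_LEN);
        std::vsnprintf(logbuf, MY_QNN_LOGBUF_LEN, fmt, argp);
        std::printf("%8.1fms [%d] %s\n", ms, level, logbuf);
    }
#else
    (void) ms; (void) level; (void) fmt; (void) argp;
#endif
}
```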
@@ -1390,11 +1361,7 @@ int qnn_instance::qnn_init(const QnnSaver_Config_t ** saver_config) {

     _qnn_interface.set_qnn_interface(_loaded_backend[backend_id]);

-#if 1
     _qnn_interface.qnn_log_create(ggml_qnn_logcallback, _qnn_log_level, &_qnn_log_handle);
-#else
-    _qnn_raw_interface.logCreate(ggml_qnn_logcallback, _qnn_log_level, &_qnn_log_handle);
-#endif
     if (nullptr == _qnn_log_handle) {
         QNN_LOG_WARN("why failed to initialize qnn log\n"); //NPU backend not work on Qualcomm SoC based low-end phone
         return 4;
@@ -1437,7 +1404,7 @@ int qnn_instance::qnn_init(const QnnSaver_Config_t ** saver_config) {
         if (QNN_PROFILE_NO_ERROR != _qnn_raw_interface.profileCreate(
                 _qnn_backend_handle, QNN_PROFILE_LEVEL_BASIC, &_qnn_profile_handle)) {
             QNN_LOG_WARN("unable to create profile handle in the backend\n");
-            return 7;
+            return 6;
         } else {
             QNN_LOG_DEBUG("initialize qnn profile successfully\n");
         }
@@ -1456,7 +1423,7 @@ int qnn_instance::qnn_init(const QnnSaver_Config_t ** saver_config) {
     _rpc_lib_handle = dlopen("libcdsprpc.so", RTLD_NOW | RTLD_LOCAL);
     if (nullptr == _rpc_lib_handle) {
         QNN_LOG_WARN("failed to load qualcomm's rpc lib, error:%s\n", dlerror());
-        return 9;
+        return 8;
     } else {
         QNN_LOG_DEBUG("load rpcmem lib successfully\n");
         set_rpcmem_initialized(true);
@@ -1470,7 +1437,7 @@ int qnn_instance::qnn_init(const QnnSaver_Config_t ** saver_config) {
            || nullptr == _pfn_rpc_mem_to_fd) {
         QNN_LOG_WARN("unable to access symbols in QNN RPC lib. dlerror(): %s", dlerror());
         dlclose(_rpc_lib_handle);
-        return 10;
+        return 9;
     }

     if (nullptr != _pfn_rpc_mem_init) // make Qualcomm's SoC based low-end phone happy
@@ -1483,7 +1450,7 @@ int qnn_instance::qnn_init(const QnnSaver_Config_t ** saver_config) {
                                &_qnn_context_handle);
     if (nullptr == _qnn_context_handle) {
         QNN_LOG_WARN("why failed to initialize qnn context\n");
-        return 8;
+        return 10;
     } else {
         QNN_LOG_DEBUG("initialize qnn context successfully\n");
     }
@@ -1695,7 +1662,7 @@ static bool ggml_qnn_can_handle_op(const struct ggml_tensor * tensor, bool b_dum
 }


-static void ggml_qnn_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_add(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     Qnn_ErrorHandle_t error = QNN_SUCCESS;
     bool graph_initialized = false;
     int64_t n_begin_time = 0LL;
@@ -1703,7 +1670,6 @@ static void ggml_qnn_add(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     int64_t n_duration = 0LL;

     qnn_instance * instance = nullptr;
-    struct ggml_backend_qnn_context * ctx = nullptr;

     std::string graph_name = "ggml_op_qnn_add";
     Qnn_GraphHandle_t graph_handle = nullptr;
@@ -1727,7 +1693,6 @@ static void ggml_qnn_add(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         QNN_LOG_WARN("pls check why QNN tensor is null");
         return;
     }
-    ctx = (struct ggml_backend_qnn_context *)g_qnn_backend->context;
     if (nullptr == ctx) {
         QNN_LOG_WARN("pls check why backend ctx is null");
         return;
@@ -1755,9 +1720,9 @@ static void ggml_qnn_add(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                   dst->type, ggml_type_name(dst->type), dst->ne[0], dst->ne[1], dst->ne[2], dst->nb[0],
                   dst->nb[1], dst->nb[2]);
     QNN_LOG_DEBUG("%d, %d, %d, %d", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
-    QNN_LOG_DEBUG("tensor0 name %s", QNN_TENSOR_GET_NAME(tensor_0));
-    QNN_LOG_DEBUG("tensor1 name %s", QNN_TENSOR_GET_NAME(tensor_1));
-    QNN_LOG_DEBUG("tensor2 name %s", QNN_TENSOR_GET_NAME(tensor_2));
+    QNN_LOG_DEBUG("tensor0 name %s", QNN_TENSOR_GET_NAME(*tensor_0));
+    QNN_LOG_DEBUG("tensor1 name %s", QNN_TENSOR_GET_NAME(*tensor_1));
+    QNN_LOG_DEBUG("tensor2 name %s", QNN_TENSOR_GET_NAME(*tensor_2));

     QNN_VER_PTR(*tensor_0)->type = QNN_TENSOR_TYPE_APP_WRITE;
     QNN_VER_PTR(*tensor_1)->type = QNN_TENSOR_TYPE_APP_WRITE;
@@ -1918,7 +1883,7 @@ static void ggml_qnn_add(const ggml_tensor * src0, const ggml_tensor * src1, ggm
  * mul_mat_f16_f32: src0 is F16 and src1 is F32.
  * mul_mat_q_f32: src0 is quantized (Q4_0, Q4_1, ...), and src1 is F32.
  */
-static void ggml_qnn_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_mul_mat(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     Qnn_ErrorHandle_t error = QNN_SUCCESS;
     bool graph_initialized = false;
     int64_t n_begin_time = 0LL;
@@ -1926,7 +1891,6 @@ static void ggml_qnn_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1,
     int64_t n_duration = 0LL;

     qnn_instance * instance = nullptr;
-    struct ggml_backend_qnn_context * ctx = nullptr;

     std::string graph_name = "ggml_op_qnn_mul_mat";
     Qnn_GraphHandle_t graph_handle = nullptr;
@@ -1952,7 +1916,6 @@ static void ggml_qnn_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1,
         QNN_LOG_WARN("pls check why QNN tensor is null");
         return;
     }
-    ctx = (struct ggml_backend_qnn_context *)g_qnn_backend->context;
     if (nullptr == ctx) {
         QNN_LOG_WARN("pls check why backend ctx is null");
         return;
@@ -1979,9 +1942,9 @@ static void ggml_qnn_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1,
                   dst->type, ggml_type_name(dst->type), dst->ne[0], dst->ne[1], dst->ne[2], dst->nb[0],
                   dst->nb[1], dst->nb[2]);
     QNN_LOG_DEBUG("%d, %d, %d, %d", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
-    QNN_LOG_DEBUG("tensor0 name %s", QNN_TENSOR_GET_NAME(tensor_0));
-    QNN_LOG_DEBUG("tensor1 name %s", QNN_TENSOR_GET_NAME(tensor_1));
-    QNN_LOG_DEBUG("tensor2 name %s", QNN_TENSOR_GET_NAME(tensor_2));
+    QNN_LOG_DEBUG("tensor0 name %s", QNN_TENSOR_GET_NAME(*tensor_0));
+    QNN_LOG_DEBUG("tensor1 name %s", QNN_TENSOR_GET_NAME(*tensor_1));
+    QNN_LOG_DEBUG("tensor2 name %s", QNN_TENSOR_GET_NAME(*tensor_2));

     QNN_VER_PTR(*tensor_0)->type = QNN_TENSOR_TYPE_APP_WRITE;
     QNN_VER_PTR(*tensor_1)->type = QNN_TENSOR_TYPE_APP_WRITE;
@@ -2129,7 +2092,7 @@ static void ggml_qnn_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1,


 //common function for GGML OPs using QNN API
-static void ggml_qnn_hanlde_op(const enum ggml_op ggmlop, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_hanlde_op(ggml_backend_qnn_context * ctx, const enum ggml_op ggmlop, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     Qnn_ErrorHandle_t error = QNN_SUCCESS;
     bool graph_initialized = false;
     int64_t n_begin_time = 0LL;
@@ -2137,7 +2100,6 @@ static void ggml_qnn_hanlde_op(const enum ggml_op ggmlop, const ggml_tensor * sr
     int64_t n_duration = 0LL;

     qnn_instance * instance = nullptr;
-    struct ggml_backend_qnn_context * ctx = nullptr;

     std::string qnn_graph_name = "ggml_qnn_graph";
     std::string qnn_op_config_name = "ggml_qnn_op_config";
@@ -2164,7 +2126,6 @@ static void ggml_qnn_hanlde_op(const enum ggml_op ggmlop, const ggml_tensor * sr
         QNN_LOG_WARN("pls check why QNN tensor is null");
         return;
     }
-    ctx = (struct ggml_backend_qnn_context *)g_qnn_backend->context;
     if (nullptr == ctx) {
         QNN_LOG_WARN("pls check why backend ctx is null");
         return;
@@ -2201,9 +2162,9 @@ static void ggml_qnn_hanlde_op(const enum ggml_op ggmlop, const ggml_tensor * sr
                   dst->type, ggml_type_name(dst->type), dst->ne[0], dst->ne[1], dst->ne[2], dst->nb[0],
                   dst->nb[1], dst->nb[2]);
     QNN_LOG_DEBUG("%d, %d, %d, %d", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
-    QNN_LOG_DEBUG("tensor0 name %s", QNN_TENSOR_GET_NAME(tensor_0));
-    QNN_LOG_DEBUG("tensor1 name %s", QNN_TENSOR_GET_NAME(tensor_1));
-    QNN_LOG_DEBUG("tensor2 name %s", QNN_TENSOR_GET_NAME(tensor_2));
+    QNN_LOG_DEBUG("tensor0 name %s", QNN_TENSOR_GET_NAME(*tensor_0));
+    QNN_LOG_DEBUG("tensor1 name %s", QNN_TENSOR_GET_NAME(*tensor_1));
+    QNN_LOG_DEBUG("tensor2 name %s", QNN_TENSOR_GET_NAME(*tensor_2));

     QNN_VER_PTR(*tensor_0)->type = QNN_TENSOR_TYPE_APP_WRITE;
     QNN_VER_PTR(*tensor_1)->type = QNN_TENSOR_TYPE_APP_WRITE;
@@ -2349,153 +2310,154 @@ static void ggml_qnn_hanlde_op(const enum ggml_op ggmlop, const ggml_tensor * sr
 }


-static void ggml_qnn_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_repeat(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_get_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_get_rows(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_acc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_acc(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_div(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_div(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_gelu(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_silu(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_gelu_quick(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_gelu_quick(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_tanh(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_tanh(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_relu(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_hardsigmoid(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_hardsigmoid(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_hardswish(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_hardswish(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_leaky_relu(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_sqr(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_norm(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_group_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_group_norm(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_concat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_concat(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_upscale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_upscale(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_pad(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_pad(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_rms_norm(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_cpy(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_qnn_cpy(src0, dst, nullptr);
+static void ggml_qnn_dup(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_qnn_cpy(ctx, src0, dst, nullptr);
     (void) src1;
 }


-static void ggml_qnn_mul_mat_id(const ggml_tensor * src0,
+static void ggml_qnn_mul_mat_id(ggml_backend_qnn_context * ctx,
+                                const ggml_tensor * src0,
                                 const ggml_tensor * src1,
                                 ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);
@@ -2504,35 +2466,35 @@ static void ggml_qnn_mul_mat_id(const ggml_tensor * src0,
 }


-static void ggml_qnn_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_scale(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_clamp(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_diag_mask_inf(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_soft_max(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_rope(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(src0));
     QNN_LOG_DEBUG("call %s\n", __func__);
@@ -2541,21 +2503,21 @@ static void ggml_qnn_rope(const ggml_tensor * src0, const ggml_tensor * src1, gg
 }


-static void ggml_qnn_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_pool2d(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_im2col(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     QNN_LOG_DEBUG("call %s\n", __func__);

     QNN_LOG_DEBUG("call %s done\n", __func__);
 }


-static void ggml_qnn_sum_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_sum_rows(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(src0));
     QNN_LOG_DEBUG("call %s\n", __func__);
@@ -2563,7 +2525,7 @@ static void ggml_qnn_sum_rows(const ggml_tensor * src0, const ggml_tensor * src1
 }


-static void ggml_qnn_argsort(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_argsort(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(src0));
     QNN_LOG_DEBUG("call %s\n", __func__);
@@ -2571,7 +2533,7 @@ static void ggml_qnn_argsort(const ggml_tensor * src0, const ggml_tensor * src1,
 }


-static void ggml_qnn_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_qnn_nop(ggml_backend_qnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
     (void) dst;
@@ -2581,7 +2543,7 @@ static void ggml_qnn_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 }


-bool ggml_qnn_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+bool ggml_qnn_compute_forward(ggml_backend_qnn_context * ctx, struct ggml_compute_params * params, struct ggml_tensor * tensor) {
     ggml_qnn_func_t func = nullptr;
     ggml_qnn_func_common_t func_common = nullptr;
@@ -2715,16 +2677,21 @@ bool ggml_qnn_compute_forward(struct ggml_compute_params * params, struct ggml_t
     }

     if (nullptr != func)
-        func(tensor->src[0], tensor->src[1], tensor);
+        func(ctx, tensor->src[0], tensor->src[1], tensor);

     if (nullptr != func_common)
-        func_common(tensor->op, tensor->src[0], tensor->src[1], tensor);
+        func_common(ctx, tensor->op, tensor->src[0], tensor->src[1], tensor);

     return true;
 }


 struct ggml_backend_qnn_buffer_context {
+    ggml_backend_qnn_buffer_context(size_t device) :
+            device(device),
+            name(GGML_QNN_NAME + std::to_string(device)) {
+    }
+
     ~ggml_backend_qnn_buffer_context() {
         if (buffer) {
             free(buffer);
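`ggml_qnn_compute_forward` still picks either a dedicated `ggml_qnn_func_t` or the common handler per op; the only difference is that whichever pointer is selected is now invoked with the backend context. A stripped-down sketch of that dispatch shape, using illustrative types and a made-up op enum rather than the real ggml ones:

```cpp
// Sketch: context-aware op dispatch, mirroring the func / func_common split above.
struct qnn_ctx;          // stands in for ggml_backend_qnn_context
struct tensor;           // stands in for ggml_tensor
enum class op { add, mul_mat, none };

typedef void (*op_fn_t)(qnn_ctx * ctx, const tensor * a, const tensor * b, tensor * out);
typedef void (*op_common_fn_t)(qnn_ctx * ctx, op o, const tensor * a, const tensor * b, tensor * out);

static void op_add(qnn_ctx *, const tensor *, const tensor *, tensor *) {}
static void op_common(qnn_ctx *, op, const tensor *, const tensor *, tensor *) {}

static bool compute_forward(qnn_ctx * ctx, op o, const tensor * a, const tensor * b, tensor * out) {
    op_fn_t        func        = nullptr;
    op_common_fn_t func_common = nullptr;

    switch (o) {
        case op::add:     func        = op_add;    break;   // dedicated kernel
        case op::mul_mat: func_common = op_common; break;   // shared path
        default:          return false;
    }
    if (func        != nullptr) func(ctx, a, b, out);
    if (func_common != nullptr) func_common(ctx, o, a, b, out);
    return true;
}
```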
@@ -2749,6 +2716,14 @@ struct ggml_backend_qnn_buffer_context {
     size_t buffer_size = 0;
     std::vector<void *> sub_buffers;
     std::vector<Qnn_Tensor_t *> qnn_tensors;
+    size_t device;
+    std::string name;
 };


+struct ggml_backend_qnn_buffer_type_context {
+    size_t device;
+    std::string name;
+};
+
@@ -2782,7 +2757,7 @@ GGML_CALL static void ggml_backend_qnn_buffer_init_tensor(ggml_backend_buffer_t

     static int idx = 0;
     char tensor_name[GGML_MAX_NAME] = { 0 };
-    snprintf(tensor_name, GGML_MAX_NAME, "tensor_%2d", idx++);
+    snprintf(tensor_name, GGML_MAX_NAME, "tensor_%04d", idx++);

     uint32_t dimensions[] = {(uint32_t) tensor->ne[0], (uint32_t) tensor->ne[1], (uint32_t) tensor->ne[2], (uint32_t) tensor->ne[3]};
     Qnn_DataType_t qnn_data_type = qnn_datatype_from_ggml_datatype(tensor->type);
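The `%2d` to `%04d` switch matters for the generated tensor names: `%2d` space-pads single digits (leaving a space inside the name) and stops being fixed-width past two digits, while `%04d` yields a stable, zero-padded four-digit suffix. A quick check:

```cpp
#include <cstdio>

int main() {
    char name_old[16], name_new[16];
    std::snprintf(name_old, sizeof(name_old), "tensor_%2d", 7);   // "tensor_ 7"
    std::snprintf(name_new, sizeof(name_new), "tensor_%04d", 7);  // "tensor_0007"
    std::printf("%s | %s\n", name_old, name_new);
    return 0;
}
```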
@@ -2888,7 +2863,8 @@ static void * ggml_qnn_host_malloc(size_t n) {


 GGML_CALL static ggml_backend_buffer_t ggml_backend_qnn_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    ggml_backend_qnn_buffer_context * ctx = new ggml_backend_qnn_buffer_context;
+    ggml_backend_qnn_buffer_type_context * buft_ctx = (ggml_backend_qnn_buffer_type_context *)buft->context;
+    ggml_backend_qnn_buffer_context * ctx = new ggml_backend_qnn_buffer_context(buft_ctx->device);

     const size_t size_page = sysconf(_SC_PAGESIZE);
@@ -2901,7 +2877,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_qnn_buffer_type_alloc_buffer
     ctx->buffer = ggml_qnn_host_malloc(size_aligned);
     ctx->buffer_size = size_aligned;

-    ctx->backend_ctx = &g_qnn_mgr[g_current_device];
+    ctx->backend_ctx = &g_qnn_mgr[buft_ctx->device];

     if (nullptr == ctx->buffer) {
         QNN_LOG_WARN("%s: failed to allocate %.2f MiB\n", __func__, size / (1 << 20));
@@ -2968,7 +2944,6 @@ GGML_CALL static void ggml_backend_qnn_free(ggml_backend_t backend) {

     if (g_qnn_mgr[ctx->device].backend != nullptr) {
         delete backend;
-        g_qnn_backend = nullptr;
         g_qnn_mgr[ctx->device].backend = nullptr;
     }
     QNN_LOG_INFO("leave %s", __func__ );
@@ -2995,7 +2970,7 @@ GGML_CALL static ggml_status ggml_backend_qnn_graph_compute(ggml_backend_t backe
         if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
             continue;
         }
-        bool ok = ggml_qnn_compute_forward(&params, node);
+        bool ok = ggml_qnn_compute_forward(ctx, &params, node);
         if (!ok) {
             QNN_LOG_DEBUG("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
         }
@@ -3017,9 +2992,9 @@ GGML_CALL static bool ggml_backend_qnn_supports_op(ggml_backend_t backend, const
 // new ggml backend(only using system memory: ggml_backend_xxx_buffer_is_host return true)
 // can following this style for mixed inference between CPU&GPU / CPU&NPU very easily
 GGML_CALL static bool ggml_backend_qnn_offload_op(ggml_backend_t backend, const ggml_tensor * tensor) {
-    GGML_UNUSED(backend);
+    ggml_backend_qnn_context * ctx = (ggml_backend_qnn_context *) backend->context;

-    return ggml_qnn_compute_forward(nullptr, (ggml_tensor*)tensor);
+    return ggml_qnn_compute_forward(ctx, nullptr, (ggml_tensor*)tensor);
 }
@@ -3104,14 +3079,20 @@ void ggml_backend_qnn_get_device_description(size_t dev_num, char * description,
 }


-ggml_backend_buffer_type_t ggml_backend_qnn_buffer_type(size_t device_index) {
-    if (device_index >= GGML_QNN_MAX_DEVICES) {
+ggml_backend_buffer_type_t ggml_backend_qnn_buffer_type(size_t device) {
+    if (device >= GGML_QNN_MAX_DEVICES) {
         QNN_LOG_DEBUG("ggml_backend_qnn_buffer_type error: device_index:%d is out of range [0, %d]\n",
-                      device_index, GGML_QNN_MAX_DEVICES - 1);
+                      device, GGML_QNN_MAX_DEVICES - 1);
         return nullptr;
     }

-    static struct ggml_backend_buffer_type ggml_backend_buffer_type_qnn = {
+    static ggml_backend_buffer_type ggml_backend_qnn_buffer_types[GGML_QNN_MAX_DEVICES];
+
+    static bool ggml_backend_qnn_buffer_type_initialized = false;
+
+    if (!ggml_backend_qnn_buffer_type_initialized) {
+        for (int i = 0; i < GGML_QNN_MAX_DEVICES; i++) {
+            ggml_backend_qnn_buffer_types[i] = {
                 /* .iface = */ {
                     /* .get_name = */ ggml_backend_qnn_buffer_type_name,
                     /* .alloc_buffer = */ ggml_backend_qnn_buffer_type_alloc_buffer,
@@ -3121,10 +3102,13 @@ ggml_backend_buffer_type_t ggml_backend_qnn_buffer_type(size_t device_index) {
                     /* .supports_backend = */ ggml_backend_qnn_buffer_type_supports_backend,
                     /* .is_host = */ ggml_backend_qnn_buffer_is_host
                 },
-                /* .context = */ nullptr,
+                /* .context = */ new ggml_backend_qnn_buffer_type_context { device, GGML_QNN_NAME + std::to_string(device) },
             };
+        }
+        ggml_backend_qnn_buffer_type_initialized = true;
+    }

-    return &ggml_backend_buffer_type_qnn;
+    return &ggml_backend_qnn_buffer_types[device];
 }
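`ggml_backend_qnn_buffer_type` now keeps one buffer type per device, each carrying its own `{device, name}` context, and fills the table on first use instead of returning a single function-local static shared by every caller. A reduced sketch of that lazy per-device table; the types and device count are placeholders, and, like the original, the first call is assumed to happen before any concurrent use:

```cpp
// Sketch: lazily initialized per-device singleton table, one context per slot.
#include <cstddef>
#include <string>

constexpr size_t MAX_DEVICES = 3;

struct buffer_type_context { size_t device; std::string name; };
struct buffer_type         { buffer_type_context * context = nullptr; };

static buffer_type * get_buffer_type(size_t device) {
    if (device >= MAX_DEVICES) return nullptr;

    static buffer_type buffer_types[MAX_DEVICES];
    static bool        initialized = false;

    if (!initialized) {                      // first caller fills every slot once
        for (size_t i = 0; i < MAX_DEVICES; i++) {
            buffer_types[i].context = new buffer_type_context{ i, "qnn" + std::to_string(i) };
        }
        initialized = true;
    }
    return &buffer_types[device];
}
```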
@@ -3137,8 +3121,10 @@ ggml_backend_buffer_type_t ggml_backend_qnn_buffer_type(size_t device_index) {
 ggml_backend_t ggml_backend_qnn_init(size_t device, const char * qnn_lib_path) {
     int result = 0;

-    if (nullptr == qnn_lib_path)
+    if (nullptr == qnn_lib_path) {
+        QNN_LOG_ERROR("invalid qnn lib path\n");
         return nullptr;
+    }

     QNN_LOG_DEBUG("device %d", device);
     QNN_LOG_DEBUG("qnn_lib_path %s", qnn_lib_path);
@@ -3147,18 +3133,6 @@ ggml_backend_t ggml_backend_qnn_init(size_t device, const char * qnn_lib_path) {
         return nullptr;
     }

-    if (nullptr != g_qnn_mgr[device].backend) {
-        QNN_LOG_ERROR("qnn backend %d(%s) already loaded", device, get_qnn_backend_name(device));
-        if (device == g_current_device) {
-            g_qnn_backend = g_qnn_mgr[device].backend;
-            QNN_LOG_INFO("re-use cached backend %d(%s)", device, get_qnn_backend_name(device));
-            return g_qnn_mgr[device].backend;
-        } else {
-            QNN_LOG_INFO("delete previous backend %d(%s)", device, get_qnn_backend_name(device));
-            ggml_backend_qnn_free(g_qnn_backend);
-        }
-    }
-
     std::string path = qnn_lib_path;
     if (QNN_BACKEND_NPU == device) {
         if (0 == setenv("LD_LIBRARY_PATH",
@@ -3215,8 +3189,6 @@ ggml_backend_t ggml_backend_qnn_init(size_t device, const char * qnn_lib_path) {
         /* .context = */ &g_qnn_mgr[device]
     };
     g_qnn_mgr[device].backend = qnn_backend;
-    g_qnn_backend = g_qnn_mgr[device].backend;
-    g_current_device = device;

     return qnn_backend;
 }
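With the `g_qnn_mgr[device].backend` early-return block and the `g_qnn_backend` / `g_current_device` assignments gone, every `ggml_backend_qnn_init` call produces an independent backend bound to its own `g_qnn_mgr` slot, which is what the commit title promises. A hedged usage sketch; the header name, device ids, and library path are assumptions for illustration:

```cpp
// Usage sketch: two QNN backend instances alive at the same time.
#include "ggml-backend.h"
#include "ggml-qnn.h"   // assumed public header declaring ggml_backend_qnn_init

int main() {
    ggml_backend_t cpu_backend = ggml_backend_qnn_init(0, "/data/local/tmp/"); // QNN CPU
    ggml_backend_t npu_backend = ggml_backend_qnn_init(2, "/data/local/tmp/"); // QNN NPU
    if (cpu_backend == nullptr || npu_backend == nullptr) return 1;

    // ... build graphs and call ggml_backend_graph_compute() on either one ...

    ggml_backend_free(npu_backend);   // each instance tears down only its own slot
    ggml_backend_free(cpu_backend);
    return 0;
}
```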
@@ -158,6 +158,7 @@ static void tensor_dump(const ggml_tensor * tensor, const char * name) {
         QNN_LOG_WARN("tensor is null");
         return;
     }

     if (tensor->type == GGML_TYPE_I8) {
         for (int h = 0; h < tensor->ne[3]; h++) {
             for (int i = 0; i < tensor->ne[2]; i++) {