Modify the code based on review comment
This commit is contained in:
parent
0da1e1fc19
commit
f50f0905bc
4 changed files with 23 additions and 69 deletions
|
@ -22,6 +22,9 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#define GGML_COMMON_DECL_C
|
||||||
|
|
||||||
|
#include "../src/ggml-common.h"
|
||||||
#include "ggml-backend.h"
|
#include "ggml-backend.h"
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
|
||||||
|
@ -29,35 +32,11 @@
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
|
||||||
* @def GGML_CANN_NAME
|
|
||||||
* @brief Define for the name of the CANN backend.
|
|
||||||
*/
|
|
||||||
#define GGML_CANN_NAME "CANN"
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Maximum number of CANN devices supported.
|
* @brief Maximum number of CANN devices supported.
|
||||||
*/
|
*/
|
||||||
#define GGML_CANN_MAX_DEVICES 16
|
#define GGML_CANN_MAX_DEVICES 16
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Structure for QK4_0 data format.
|
|
||||||
*/
|
|
||||||
#define QK4_0 32
|
|
||||||
typedef struct {
|
|
||||||
uint16_t d; /**< Delta */
|
|
||||||
uint8_t qs[QK4_0 / 2]; /**< Nibbles / quants */
|
|
||||||
} block_q4_0;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Structure for QK8_0 data format.
|
|
||||||
*/
|
|
||||||
#define QK8_0 32
|
|
||||||
typedef struct {
|
|
||||||
uint16_t d; /**< Delta */
|
|
||||||
int8_t qs[QK8_0]; /**< Quants */
|
|
||||||
} block_q8_0;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Initializes the CANN backend for a specified device.
|
* @brief Initializes the CANN backend for a specified device.
|
||||||
*
|
*
|
||||||
|
@ -133,16 +112,6 @@ GGML_API GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device,
|
||||||
size_t* free,
|
size_t* free,
|
||||||
size_t* total);
|
size_t* total);
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Initializes resources required by the CANN backend.
|
|
||||||
*/
|
|
||||||
void ggml_cann_backend_init(void);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Frees resources used by the CANN backend.
|
|
||||||
*/
|
|
||||||
void ggml_cann_backend_free(void);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -96,7 +96,7 @@ static ggml_cann_device_info ggml_cann_init() {
|
||||||
aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
|
aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
|
||||||
|
|
||||||
if (err != ACL_SUCCESS) {
|
if (err != ACL_SUCCESS) {
|
||||||
fprintf(stderr, "%s: failed to initialize " GGML_CANN_NAME ": %s\n",
|
fprintf(stderr, "%s: failed to initialize CANN: %s\n",
|
||||||
__func__, aclGetRecentErrMsg());
|
__func__, aclGetRecentErrMsg());
|
||||||
return info;
|
return info;
|
||||||
}
|
}
|
||||||
|
@ -464,7 +464,6 @@ struct ggml_backend_cann_buffer_context {
|
||||||
int32_t device; ///< The device ID associated with this buffer context.
|
int32_t device; ///< The device ID associated with this buffer context.
|
||||||
void* dev_ptr =
|
void* dev_ptr =
|
||||||
nullptr; ///< Pointer to the device memory allocated for the buffer.
|
nullptr; ///< Pointer to the device memory allocated for the buffer.
|
||||||
std::string name; ///< Name of the buffer context.
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Constructor to initialize the CANN buffer context.
|
* @brief Constructor to initialize the CANN buffer context.
|
||||||
|
@ -474,8 +473,7 @@ struct ggml_backend_cann_buffer_context {
|
||||||
*/
|
*/
|
||||||
ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
|
ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
|
||||||
: device(device),
|
: device(device),
|
||||||
dev_ptr(dev_ptr),
|
dev_ptr(dev_ptr) {}
|
||||||
name(GGML_CANN_NAME + std::to_string(device)) {}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Destructor to free the device memory allocated for the buffer.
|
* @brief Destructor to free the device memory allocated for the buffer.
|
||||||
|
@ -495,9 +493,9 @@ struct ggml_backend_cann_buffer_context {
|
||||||
|
|
||||||
GGML_CALL static const char* ggml_backend_cann_buffer_get_name(
|
GGML_CALL static const char* ggml_backend_cann_buffer_get_name(
|
||||||
ggml_backend_buffer_t buffer) {
|
ggml_backend_buffer_t buffer) {
|
||||||
ggml_backend_cann_buffer_context* ctx =
|
return "CANN";
|
||||||
(ggml_backend_cann_buffer_context*)buffer->context;
|
|
||||||
return ctx->name.c_str();
|
GGML_UNUSED(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1004,10 +1002,9 @@ struct ggml_backend_cann_buffer_type_context {
|
||||||
*/
|
*/
|
||||||
GGML_CALL static const char* ggml_backend_cann_buffer_type_name(
|
GGML_CALL static const char* ggml_backend_cann_buffer_type_name(
|
||||||
ggml_backend_buffer_type_t buft) {
|
ggml_backend_buffer_type_t buft) {
|
||||||
ggml_backend_cann_buffer_type_context* ctx =
|
return "CANN";
|
||||||
(ggml_backend_cann_buffer_type_context*)buft->context;
|
|
||||||
|
|
||||||
return ctx->name.c_str();
|
GGML_UNUSED(buft);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1152,7 +1149,7 @@ ggml_backend_cann_buffer_type(int32_t device) {
|
||||||
/* .iface = */ ggml_backend_cann_buffer_type_interface,
|
/* .iface = */ ggml_backend_cann_buffer_type_interface,
|
||||||
/* .context = */
|
/* .context = */
|
||||||
new ggml_backend_cann_buffer_type_context{
|
new ggml_backend_cann_buffer_type_context{
|
||||||
i, GGML_CANN_NAME + std::to_string(i)},
|
i, "CANN" + std::to_string(i)},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
ggml_backend_cann_buffer_type_initialized = true;
|
ggml_backend_cann_buffer_type_initialized = true;
|
||||||
|
@ -1344,6 +1341,12 @@ GGML_CALL static void ggml_backend_cann_free(ggml_backend_t backend) {
|
||||||
(ggml_backend_cann_context*)backend->context;
|
(ggml_backend_cann_context*)backend->context;
|
||||||
ACL_CHECK(aclrtSynchronizeDevice());
|
ACL_CHECK(aclrtSynchronizeDevice());
|
||||||
ACL_CHECK(aclrtResetDevice(cann_ctx->device));
|
ACL_CHECK(aclrtResetDevice(cann_ctx->device));
|
||||||
|
|
||||||
|
// finalize when last backend freed.
|
||||||
|
if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
|
||||||
|
ACL_CHECK(aclFinalize());
|
||||||
|
}
|
||||||
|
|
||||||
delete cann_ctx;
|
delete cann_ctx;
|
||||||
delete backend;
|
delete backend;
|
||||||
}
|
}
|
||||||
|
@ -1703,15 +1706,9 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
|
||||||
*/
|
*/
|
||||||
GGML_CALL static bool ggml_backend_cann_supports_buft(
|
GGML_CALL static bool ggml_backend_cann_supports_buft(
|
||||||
ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
|
ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
|
||||||
if (ggml_backend_buft_is_cann(buft)) {
|
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
|
||||||
ggml_backend_cann_context* cann_ctx =
|
|
||||||
(ggml_backend_cann_context*)backend->context;
|
|
||||||
ggml_backend_cann_buffer_type_context* buft_ctx =
|
|
||||||
(ggml_backend_cann_buffer_type_context*)buft->context;
|
|
||||||
return buft_ctx->device == cann_ctx->device;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
GGML_UNUSED(backend);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1870,6 +1867,7 @@ static ggml_guid_t ggml_backend_cann_guid() {
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device) {
|
GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device) {
|
||||||
|
aclInit(nullptr);
|
||||||
if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
|
if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
|
||||||
fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
|
fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -1945,19 +1943,14 @@ extern "C" GGML_CALL int ggml_backend_cann_reg_devices();
|
||||||
* @return int The number of CANN devices registered.
|
* @return int The number of CANN devices registered.
|
||||||
*/
|
*/
|
||||||
GGML_CALL int ggml_backend_cann_reg_devices() {
|
GGML_CALL int ggml_backend_cann_reg_devices() {
|
||||||
aclInit(nullptr);
|
|
||||||
uint32_t device_count = ggml_backend_cann_get_device_count();
|
uint32_t device_count = ggml_backend_cann_get_device_count();
|
||||||
// initialization
|
// initialization
|
||||||
for (uint32_t i = 0; i < device_count; i++) {
|
for (uint32_t i = 0; i < device_count; i++) {
|
||||||
char name[128];
|
char name[128];
|
||||||
snprintf(name, sizeof(name), "%s%d", GGML_CANN_NAME, i);
|
snprintf(name, sizeof(name), "CANN%d", i);
|
||||||
ggml_backend_register(name, ggml_backend_reg_cann_init,
|
ggml_backend_register(name, ggml_backend_reg_cann_init,
|
||||||
ggml_backend_cann_buffer_type(i),
|
ggml_backend_cann_buffer_type(i),
|
||||||
(void*)(intptr_t)i);
|
(void*)(intptr_t)i);
|
||||||
}
|
}
|
||||||
return device_count;
|
return device_count;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cann_backend_init(void) { ACL_CHECK(aclInit(nullptr)); }
|
|
||||||
|
|
||||||
void ggml_cann_backend_free(void) { ACL_CHECK(aclFinalize()); }
|
|
||||||
|
|
|
@ -221,7 +221,7 @@ struct ggml_backend_cann_context {
|
||||||
* @param device Device ID.
|
* @param device Device ID.
|
||||||
*/
|
*/
|
||||||
explicit ggml_backend_cann_context(int device)
|
explicit ggml_backend_cann_context(int device)
|
||||||
: device(device), name(GGML_CANN_NAME + std::to_string(device)) {}
|
: device(device), name("CANN" + std::to_string(device)) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Destructor for cleaning up resources.
|
* @brief Destructor for cleaning up resources.
|
||||||
|
|
|
@ -18916,10 +18916,6 @@ void llama_backend_init(void) {
|
||||||
struct ggml_context * ctx = ggml_init(params);
|
struct ggml_context * ctx = ggml_init(params);
|
||||||
ggml_free(ctx);
|
ggml_free(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(GGML_USE_CANN)
|
|
||||||
ggml_cann_backend_init();
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_numa_init(enum ggml_numa_strategy numa) {
|
void llama_numa_init(enum ggml_numa_strategy numa) {
|
||||||
|
@ -18929,10 +18925,6 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_backend_free(void) {
|
void llama_backend_free(void) {
|
||||||
#if defined(GGML_USE_CANN)
|
|
||||||
ggml_cann_backend_free();
|
|
||||||
#endif
|
|
||||||
|
|
||||||
ggml_quantize_free();
|
ggml_quantize_free();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue