Modify the code based on review comment

This commit is contained in:
huafengchun 2024-07-16 07:15:26 +00:00
parent 0da1e1fc19
commit f50f0905bc
4 changed files with 23 additions and 69 deletions

View file

@ -22,6 +22,9 @@
#pragma once #pragma once
#define GGML_COMMON_DECL_C
#include "../src/ggml-common.h"
#include "ggml-backend.h" #include "ggml-backend.h"
#include "ggml.h" #include "ggml.h"
@ -29,35 +32,11 @@
extern "C" { extern "C" {
#endif #endif
/**
* @def GGML_CANN_NAME
* @brief Define for the name of the CANN backend.
*/
#define GGML_CANN_NAME "CANN"
/** /**
* @brief Maximum number of CANN devices supported. * @brief Maximum number of CANN devices supported.
*/ */
#define GGML_CANN_MAX_DEVICES 16 #define GGML_CANN_MAX_DEVICES 16
/**
* @brief Structure for QK4_0 data format.
*/
#define QK4_0 32
typedef struct {
uint16_t d; /**< Delta */
uint8_t qs[QK4_0 / 2]; /**< Nibbles / quants */
} block_q4_0;
/**
* @brief Structure for QK8_0 data format.
*/
#define QK8_0 32
typedef struct {
uint16_t d; /**< Delta */
int8_t qs[QK8_0]; /**< Quants */
} block_q8_0;
/** /**
* @brief Initializes the CANN backend for a specified device. * @brief Initializes the CANN backend for a specified device.
* *
@ -133,16 +112,6 @@ GGML_API GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device,
size_t* free, size_t* free,
size_t* total); size_t* total);
/**
* @brief Initializes resources required by the CANN backend.
*/
void ggml_cann_backend_init(void);
/**
* @brief Frees resources used by the CANN backend.
*/
void ggml_cann_backend_free(void);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View file

@ -96,7 +96,7 @@ static ggml_cann_device_info ggml_cann_init() {
aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count); aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
if (err != ACL_SUCCESS) { if (err != ACL_SUCCESS) {
fprintf(stderr, "%s: failed to initialize " GGML_CANN_NAME ": %s\n", fprintf(stderr, "%s: failed to initialize CANN: %s\n",
__func__, aclGetRecentErrMsg()); __func__, aclGetRecentErrMsg());
return info; return info;
} }
@ -464,7 +464,6 @@ struct ggml_backend_cann_buffer_context {
int32_t device; ///< The device ID associated with this buffer context. int32_t device; ///< The device ID associated with this buffer context.
void* dev_ptr = void* dev_ptr =
nullptr; ///< Pointer to the device memory allocated for the buffer. nullptr; ///< Pointer to the device memory allocated for the buffer.
std::string name; ///< Name of the buffer context.
/** /**
* @brief Constructor to initialize the CANN buffer context. * @brief Constructor to initialize the CANN buffer context.
@ -474,8 +473,7 @@ struct ggml_backend_cann_buffer_context {
*/ */
ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr) ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
: device(device), : device(device),
dev_ptr(dev_ptr), dev_ptr(dev_ptr) {}
name(GGML_CANN_NAME + std::to_string(device)) {}
/** /**
* @brief Destructor to free the device memory allocated for the buffer. * @brief Destructor to free the device memory allocated for the buffer.
@ -495,9 +493,9 @@ struct ggml_backend_cann_buffer_context {
GGML_CALL static const char* ggml_backend_cann_buffer_get_name( GGML_CALL static const char* ggml_backend_cann_buffer_get_name(
ggml_backend_buffer_t buffer) { ggml_backend_buffer_t buffer) {
ggml_backend_cann_buffer_context* ctx = return "CANN";
(ggml_backend_cann_buffer_context*)buffer->context;
return ctx->name.c_str(); GGML_UNUSED(buffer);
} }
/** /**
@ -1004,10 +1002,9 @@ struct ggml_backend_cann_buffer_type_context {
*/ */
GGML_CALL static const char* ggml_backend_cann_buffer_type_name( GGML_CALL static const char* ggml_backend_cann_buffer_type_name(
ggml_backend_buffer_type_t buft) { ggml_backend_buffer_type_t buft) {
ggml_backend_cann_buffer_type_context* ctx = return "CANN";
(ggml_backend_cann_buffer_type_context*)buft->context;
return ctx->name.c_str(); GGML_UNUSED(buft);
} }
/** /**
@ -1152,7 +1149,7 @@ ggml_backend_cann_buffer_type(int32_t device) {
/* .iface = */ ggml_backend_cann_buffer_type_interface, /* .iface = */ ggml_backend_cann_buffer_type_interface,
/* .context = */ /* .context = */
new ggml_backend_cann_buffer_type_context{ new ggml_backend_cann_buffer_type_context{
i, GGML_CANN_NAME + std::to_string(i)}, i, "CANN" + std::to_string(i)},
}; };
} }
ggml_backend_cann_buffer_type_initialized = true; ggml_backend_cann_buffer_type_initialized = true;
@ -1344,6 +1341,12 @@ GGML_CALL static void ggml_backend_cann_free(ggml_backend_t backend) {
(ggml_backend_cann_context*)backend->context; (ggml_backend_cann_context*)backend->context;
ACL_CHECK(aclrtSynchronizeDevice()); ACL_CHECK(aclrtSynchronizeDevice());
ACL_CHECK(aclrtResetDevice(cann_ctx->device)); ACL_CHECK(aclrtResetDevice(cann_ctx->device));
// finalize when last backend freed.
if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
ACL_CHECK(aclFinalize());
}
delete cann_ctx; delete cann_ctx;
delete backend; delete backend;
} }
@ -1703,15 +1706,9 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
*/ */
GGML_CALL static bool ggml_backend_cann_supports_buft( GGML_CALL static bool ggml_backend_cann_supports_buft(
ggml_backend_t backend, ggml_backend_buffer_type_t buft) { ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
if (ggml_backend_buft_is_cann(buft)) { return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
ggml_backend_cann_context* cann_ctx =
(ggml_backend_cann_context*)backend->context;
ggml_backend_cann_buffer_type_context* buft_ctx =
(ggml_backend_cann_buffer_type_context*)buft->context;
return buft_ctx->device == cann_ctx->device;
}
return false; GGML_UNUSED(backend);
} }
/** /**
@ -1870,6 +1867,7 @@ static ggml_guid_t ggml_backend_cann_guid() {
} }
GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device) { GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device) {
aclInit(nullptr);
if (device < 0 || device >= ggml_backend_cann_get_device_count()) { if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
fprintf(stderr, "%s: error: invalid device %d\n", __func__, device); fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
return nullptr; return nullptr;
@ -1945,19 +1943,14 @@ extern "C" GGML_CALL int ggml_backend_cann_reg_devices();
* @return int The number of CANN devices registered. * @return int The number of CANN devices registered.
*/ */
GGML_CALL int ggml_backend_cann_reg_devices() { GGML_CALL int ggml_backend_cann_reg_devices() {
aclInit(nullptr);
uint32_t device_count = ggml_backend_cann_get_device_count(); uint32_t device_count = ggml_backend_cann_get_device_count();
// initialization // initialization
for (uint32_t i = 0; i < device_count; i++) { for (uint32_t i = 0; i < device_count; i++) {
char name[128]; char name[128];
snprintf(name, sizeof(name), "%s%d", GGML_CANN_NAME, i); snprintf(name, sizeof(name), "CANN%d", i);
ggml_backend_register(name, ggml_backend_reg_cann_init, ggml_backend_register(name, ggml_backend_reg_cann_init,
ggml_backend_cann_buffer_type(i), ggml_backend_cann_buffer_type(i),
(void*)(intptr_t)i); (void*)(intptr_t)i);
} }
return device_count; return device_count;
} }
void ggml_cann_backend_init(void) { ACL_CHECK(aclInit(nullptr)); }
void ggml_cann_backend_free(void) { ACL_CHECK(aclFinalize()); }

View file

@ -221,7 +221,7 @@ struct ggml_backend_cann_context {
* @param device Device ID. * @param device Device ID.
*/ */
explicit ggml_backend_cann_context(int device) explicit ggml_backend_cann_context(int device)
: device(device), name(GGML_CANN_NAME + std::to_string(device)) {} : device(device), name("CANN" + std::to_string(device)) {}
/** /**
* @brief Destructor for cleaning up resources. * @brief Destructor for cleaning up resources.

View file

@ -18916,10 +18916,6 @@ void llama_backend_init(void) {
struct ggml_context * ctx = ggml_init(params); struct ggml_context * ctx = ggml_init(params);
ggml_free(ctx); ggml_free(ctx);
} }
#if defined(GGML_USE_CANN)
ggml_cann_backend_init();
#endif
} }
void llama_numa_init(enum ggml_numa_strategy numa) { void llama_numa_init(enum ggml_numa_strategy numa) {
@ -18929,10 +18925,6 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
} }
void llama_backend_free(void) { void llama_backend_free(void) {
#if defined(GGML_USE_CANN)
ggml_cann_backend_free();
#endif
ggml_quantize_free(); ggml_quantize_free();
} }