Modify the code based on review comment

This commit is contained in:
huafengchun 2024-07-16 07:15:26 +00:00
parent 0da1e1fc19
commit f50f0905bc
4 changed files with 23 additions and 69 deletions

View file

@ -22,6 +22,9 @@
#pragma once
#define GGML_COMMON_DECL_C
#include "../src/ggml-common.h"
#include "ggml-backend.h"
#include "ggml.h"
@ -29,35 +32,11 @@
extern "C" {
#endif
/**
* @def GGML_CANN_NAME
* @brief Define for the name of the CANN backend.
*/
#define GGML_CANN_NAME "CANN"
/**
* @brief Maximum number of CANN devices supported.
*/
#define GGML_CANN_MAX_DEVICES 16
/**
* @brief Structure for QK4_0 data format.
*/
#define QK4_0 32
typedef struct {
uint16_t d; /**< Delta */
uint8_t qs[QK4_0 / 2]; /**< Nibbles / quants */
} block_q4_0;
/**
* @brief Structure for QK8_0 data format.
*/
#define QK8_0 32
typedef struct {
uint16_t d; /**< Delta */
int8_t qs[QK8_0]; /**< Quants */
} block_q8_0;
/**
* @brief Initializes the CANN backend for a specified device.
*
@ -133,16 +112,6 @@ GGML_API GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device,
size_t* free,
size_t* total);
/**
* @brief Initializes resources required by the CANN backend.
*/
void ggml_cann_backend_init(void);
/**
* @brief Frees resources used by the CANN backend.
*/
void ggml_cann_backend_free(void);
#ifdef __cplusplus
}
#endif

View file

@ -96,7 +96,7 @@ static ggml_cann_device_info ggml_cann_init() {
aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
if (err != ACL_SUCCESS) {
fprintf(stderr, "%s: failed to initialize " GGML_CANN_NAME ": %s\n",
fprintf(stderr, "%s: failed to initialize CANN: %s\n",
__func__, aclGetRecentErrMsg());
return info;
}
@ -464,7 +464,6 @@ struct ggml_backend_cann_buffer_context {
int32_t device; ///< The device ID associated with this buffer context.
void* dev_ptr =
nullptr; ///< Pointer to the device memory allocated for the buffer.
std::string name; ///< Name of the buffer context.
/**
* @brief Constructor to initialize the CANN buffer context.
@ -474,8 +473,7 @@ struct ggml_backend_cann_buffer_context {
*/
ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
: device(device),
dev_ptr(dev_ptr),
name(GGML_CANN_NAME + std::to_string(device)) {}
dev_ptr(dev_ptr) {}
/**
* @brief Destructor to free the device memory allocated for the buffer.
@ -495,9 +493,9 @@ struct ggml_backend_cann_buffer_context {
GGML_CALL static const char* ggml_backend_cann_buffer_get_name(
ggml_backend_buffer_t buffer) {
ggml_backend_cann_buffer_context* ctx =
(ggml_backend_cann_buffer_context*)buffer->context;
return ctx->name.c_str();
return "CANN";
GGML_UNUSED(buffer);
}
/**
@ -1004,10 +1002,9 @@ struct ggml_backend_cann_buffer_type_context {
*/
GGML_CALL static const char* ggml_backend_cann_buffer_type_name(
ggml_backend_buffer_type_t buft) {
ggml_backend_cann_buffer_type_context* ctx =
(ggml_backend_cann_buffer_type_context*)buft->context;
return "CANN";
return ctx->name.c_str();
GGML_UNUSED(buft);
}
/**
@ -1152,7 +1149,7 @@ ggml_backend_cann_buffer_type(int32_t device) {
/* .iface = */ ggml_backend_cann_buffer_type_interface,
/* .context = */
new ggml_backend_cann_buffer_type_context{
i, GGML_CANN_NAME + std::to_string(i)},
i, "CANN" + std::to_string(i)},
};
}
ggml_backend_cann_buffer_type_initialized = true;
@ -1344,6 +1341,12 @@ GGML_CALL static void ggml_backend_cann_free(ggml_backend_t backend) {
(ggml_backend_cann_context*)backend->context;
ACL_CHECK(aclrtSynchronizeDevice());
ACL_CHECK(aclrtResetDevice(cann_ctx->device));
// finalize when last backend freed.
if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
ACL_CHECK(aclFinalize());
}
delete cann_ctx;
delete backend;
}
@ -1703,15 +1706,9 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
*/
GGML_CALL static bool ggml_backend_cann_supports_buft(
ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
if (ggml_backend_buft_is_cann(buft)) {
ggml_backend_cann_context* cann_ctx =
(ggml_backend_cann_context*)backend->context;
ggml_backend_cann_buffer_type_context* buft_ctx =
(ggml_backend_cann_buffer_type_context*)buft->context;
return buft_ctx->device == cann_ctx->device;
}
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
return false;
GGML_UNUSED(backend);
}
/**
@ -1870,6 +1867,7 @@ static ggml_guid_t ggml_backend_cann_guid() {
}
GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device) {
aclInit(nullptr);
if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
return nullptr;
@ -1945,19 +1943,14 @@ extern "C" GGML_CALL int ggml_backend_cann_reg_devices();
* @return int The number of CANN devices registered.
*/
GGML_CALL int ggml_backend_cann_reg_devices() {
aclInit(nullptr);
uint32_t device_count = ggml_backend_cann_get_device_count();
// initialization
for (uint32_t i = 0; i < device_count; i++) {
char name[128];
snprintf(name, sizeof(name), "%s%d", GGML_CANN_NAME, i);
snprintf(name, sizeof(name), "CANN%d", i);
ggml_backend_register(name, ggml_backend_reg_cann_init,
ggml_backend_cann_buffer_type(i),
(void*)(intptr_t)i);
}
return device_count;
}
void ggml_cann_backend_init(void) { ACL_CHECK(aclInit(nullptr)); }
void ggml_cann_backend_free(void) { ACL_CHECK(aclFinalize()); }

View file

@ -221,7 +221,7 @@ struct ggml_backend_cann_context {
* @param device Device ID.
*/
explicit ggml_backend_cann_context(int device)
: device(device), name(GGML_CANN_NAME + std::to_string(device)) {}
: device(device), name("CANN" + std::to_string(device)) {}
/**
* @brief Destructor for cleaning up resources.

View file

@ -18916,10 +18916,6 @@ void llama_backend_init(void) {
struct ggml_context * ctx = ggml_init(params);
ggml_free(ctx);
}
#if defined(GGML_USE_CANN)
ggml_cann_backend_init();
#endif
}
void llama_numa_init(enum ggml_numa_strategy numa) {
@ -18929,10 +18925,6 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
}
void llama_backend_free(void) {
#if defined(GGML_USE_CANN)
ggml_cann_backend_free();
#endif
ggml_quantize_free();
}