Modify the code based on review comment
This commit is contained in:
parent
0da1e1fc19
commit
f50f0905bc
4 changed files with 23 additions and 69 deletions
|
@ -22,6 +22,9 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
|
||||
#include "../src/ggml-common.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml.h"
|
||||
|
||||
|
@ -29,35 +32,11 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* @def GGML_CANN_NAME
|
||||
* @brief Define for the name of the CANN backend.
|
||||
*/
|
||||
#define GGML_CANN_NAME "CANN"
|
||||
|
||||
/**
|
||||
* @brief Maximum number of CANN devices supported.
|
||||
*/
|
||||
#define GGML_CANN_MAX_DEVICES 16
|
||||
|
||||
/**
|
||||
* @brief Structure for QK4_0 data format.
|
||||
*/
|
||||
#define QK4_0 32
|
||||
typedef struct {
|
||||
uint16_t d; /**< Delta */
|
||||
uint8_t qs[QK4_0 / 2]; /**< Nibbles / quants */
|
||||
} block_q4_0;
|
||||
|
||||
/**
|
||||
* @brief Structure for QK8_0 data format.
|
||||
*/
|
||||
#define QK8_0 32
|
||||
typedef struct {
|
||||
uint16_t d; /**< Delta */
|
||||
int8_t qs[QK8_0]; /**< Quants */
|
||||
} block_q8_0;
|
||||
|
||||
/**
|
||||
* @brief Initializes the CANN backend for a specified device.
|
||||
*
|
||||
|
@ -133,16 +112,6 @@ GGML_API GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device,
|
|||
size_t* free,
|
||||
size_t* total);
|
||||
|
||||
/**
|
||||
* @brief Initializes resources required by the CANN backend.
|
||||
*/
|
||||
void ggml_cann_backend_init(void);
|
||||
|
||||
/**
|
||||
* @brief Frees resources used by the CANN backend.
|
||||
*/
|
||||
void ggml_cann_backend_free(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -96,7 +96,7 @@ static ggml_cann_device_info ggml_cann_init() {
|
|||
aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
|
||||
|
||||
if (err != ACL_SUCCESS) {
|
||||
fprintf(stderr, "%s: failed to initialize " GGML_CANN_NAME ": %s\n",
|
||||
fprintf(stderr, "%s: failed to initialize CANN: %s\n",
|
||||
__func__, aclGetRecentErrMsg());
|
||||
return info;
|
||||
}
|
||||
|
@ -464,7 +464,6 @@ struct ggml_backend_cann_buffer_context {
|
|||
int32_t device; ///< The device ID associated with this buffer context.
|
||||
void* dev_ptr =
|
||||
nullptr; ///< Pointer to the device memory allocated for the buffer.
|
||||
std::string name; ///< Name of the buffer context.
|
||||
|
||||
/**
|
||||
* @brief Constructor to initialize the CANN buffer context.
|
||||
|
@ -474,8 +473,7 @@ struct ggml_backend_cann_buffer_context {
|
|||
*/
|
||||
ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
|
||||
: device(device),
|
||||
dev_ptr(dev_ptr),
|
||||
name(GGML_CANN_NAME + std::to_string(device)) {}
|
||||
dev_ptr(dev_ptr) {}
|
||||
|
||||
/**
|
||||
* @brief Destructor to free the device memory allocated for the buffer.
|
||||
|
@ -495,9 +493,9 @@ struct ggml_backend_cann_buffer_context {
|
|||
|
||||
GGML_CALL static const char* ggml_backend_cann_buffer_get_name(
|
||||
ggml_backend_buffer_t buffer) {
|
||||
ggml_backend_cann_buffer_context* ctx =
|
||||
(ggml_backend_cann_buffer_context*)buffer->context;
|
||||
return ctx->name.c_str();
|
||||
return "CANN";
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1004,10 +1002,9 @@ struct ggml_backend_cann_buffer_type_context {
|
|||
*/
|
||||
GGML_CALL static const char* ggml_backend_cann_buffer_type_name(
|
||||
ggml_backend_buffer_type_t buft) {
|
||||
ggml_backend_cann_buffer_type_context* ctx =
|
||||
(ggml_backend_cann_buffer_type_context*)buft->context;
|
||||
return "CANN";
|
||||
|
||||
return ctx->name.c_str();
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1151,8 +1148,8 @@ ggml_backend_cann_buffer_type(int32_t device) {
|
|||
ggml_backend_cann_buffer_types[i] = {
|
||||
/* .iface = */ ggml_backend_cann_buffer_type_interface,
|
||||
/* .context = */
|
||||
new ggml_backend_cann_buffer_type_context{
|
||||
i, GGML_CANN_NAME + std::to_string(i)},
|
||||
new ggml_backend_cann_buffer_type_context{
|
||||
i, "CANN" + std::to_string(i)},
|
||||
};
|
||||
}
|
||||
ggml_backend_cann_buffer_type_initialized = true;
|
||||
|
@ -1344,6 +1341,12 @@ GGML_CALL static void ggml_backend_cann_free(ggml_backend_t backend) {
|
|||
(ggml_backend_cann_context*)backend->context;
|
||||
ACL_CHECK(aclrtSynchronizeDevice());
|
||||
ACL_CHECK(aclrtResetDevice(cann_ctx->device));
|
||||
|
||||
// finalize when last backend freed.
|
||||
if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
|
||||
ACL_CHECK(aclFinalize());
|
||||
}
|
||||
|
||||
delete cann_ctx;
|
||||
delete backend;
|
||||
}
|
||||
|
@ -1703,15 +1706,9 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
|
|||
*/
|
||||
GGML_CALL static bool ggml_backend_cann_supports_buft(
|
||||
ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
|
||||
if (ggml_backend_buft_is_cann(buft)) {
|
||||
ggml_backend_cann_context* cann_ctx =
|
||||
(ggml_backend_cann_context*)backend->context;
|
||||
ggml_backend_cann_buffer_type_context* buft_ctx =
|
||||
(ggml_backend_cann_buffer_type_context*)buft->context;
|
||||
return buft_ctx->device == cann_ctx->device;
|
||||
}
|
||||
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
|
||||
|
||||
return false;
|
||||
GGML_UNUSED(backend);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1870,6 +1867,7 @@ static ggml_guid_t ggml_backend_cann_guid() {
|
|||
}
|
||||
|
||||
GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device) {
|
||||
aclInit(nullptr);
|
||||
if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
|
||||
fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
|
||||
return nullptr;
|
||||
|
@ -1945,19 +1943,14 @@ extern "C" GGML_CALL int ggml_backend_cann_reg_devices();
|
|||
* @return int The number of CANN devices registered.
|
||||
*/
|
||||
GGML_CALL int ggml_backend_cann_reg_devices() {
|
||||
aclInit(nullptr);
|
||||
uint32_t device_count = ggml_backend_cann_get_device_count();
|
||||
// initialization
|
||||
for (uint32_t i = 0; i < device_count; i++) {
|
||||
char name[128];
|
||||
snprintf(name, sizeof(name), "%s%d", GGML_CANN_NAME, i);
|
||||
snprintf(name, sizeof(name), "CANN%d", i);
|
||||
ggml_backend_register(name, ggml_backend_reg_cann_init,
|
||||
ggml_backend_cann_buffer_type(i),
|
||||
(void*)(intptr_t)i);
|
||||
}
|
||||
return device_count;
|
||||
}
|
||||
|
||||
void ggml_cann_backend_init(void) { ACL_CHECK(aclInit(nullptr)); }
|
||||
|
||||
void ggml_cann_backend_free(void) { ACL_CHECK(aclFinalize()); }
|
||||
|
|
|
@ -221,7 +221,7 @@ struct ggml_backend_cann_context {
|
|||
* @param device Device ID.
|
||||
*/
|
||||
explicit ggml_backend_cann_context(int device)
|
||||
: device(device), name(GGML_CANN_NAME + std::to_string(device)) {}
|
||||
: device(device), name("CANN" + std::to_string(device)) {}
|
||||
|
||||
/**
|
||||
* @brief Destructor for cleaning up resources.
|
||||
|
|
|
@ -18916,10 +18916,6 @@ void llama_backend_init(void) {
|
|||
struct ggml_context * ctx = ggml_init(params);
|
||||
ggml_free(ctx);
|
||||
}
|
||||
|
||||
#if defined(GGML_USE_CANN)
|
||||
ggml_cann_backend_init();
|
||||
#endif
|
||||
}
|
||||
|
||||
void llama_numa_init(enum ggml_numa_strategy numa) {
|
||||
|
@ -18929,10 +18925,6 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
|
|||
}
|
||||
|
||||
void llama_backend_free(void) {
|
||||
#if defined(GGML_USE_CANN)
|
||||
ggml_cann_backend_free();
|
||||
#endif
|
||||
|
||||
ggml_quantize_free();
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue