some modifications after review

This commit is contained in:
shanshan shen 2024-11-26 07:09:55 +00:00
parent 58652e42c3
commit 1c79893ca2

View file

@ -471,11 +471,8 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
*/ */
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device( std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
int device) { int device) {
if (device == 0) {
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device)); return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
} }
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
}
// cann buffer // cann buffer
/** /**
@ -486,22 +483,21 @@ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
*/ */
struct ggml_backend_cann_buffer_context { struct ggml_backend_cann_buffer_context {
int32_t device; ///< The device ID associated with this buffer context. int32_t device; ///< The device ID associated with this buffer context.
ggml_cann_pool_alloc* alloc; ///< Pointer to the device memory allocated for the buffer. void* dev_ptr = nullptr;
/** /**
* @brief Constructor to initialize the CANN buffer context. * @brief Constructor to initialize the CANN buffer context.
* *
* @param device The device ID associated with this buffer context. * @param device The device ID associated with this buffer context.
* @param alloc Pointer to the device memory allocated for the buffer.
*/ */
ggml_backend_cann_buffer_context(int32_t device, ggml_cann_pool_alloc* alloc) ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
: device(device), : device(device),
alloc(alloc) {} dev_ptr(dev_ptr) {}
/** /**
* @brief Destructor to free the device memory allocated for the buffer. * @brief Destructor to free the device memory allocated for the buffer.
*/ */
~ggml_backend_cann_buffer_context() { delete alloc; } ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr));}
}; };
/** /**
@ -547,7 +543,7 @@ static void* ggml_backend_cann_buffer_get_base(
ggml_backend_buffer_t buffer) { ggml_backend_buffer_t buffer) {
ggml_backend_cann_buffer_context* ctx = ggml_backend_cann_buffer_context* ctx =
(ggml_backend_cann_buffer_context*)buffer->context; (ggml_backend_cann_buffer_context*)buffer->context;
return ctx->alloc->get(); return ctx->dev_ptr;
} }
/** /**
@ -954,7 +950,7 @@ static void ggml_backend_cann_buffer_clear(
(ggml_backend_cann_buffer_context*)buffer->context; (ggml_backend_cann_buffer_context*)buffer->context;
ggml_cann_set_device(ctx->device); ggml_cann_set_device(ctx->device);
ACL_CHECK(aclrtMemset(ctx->alloc->get(), buffer->size, value, buffer->size)); ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
} }
/** /**
@ -1016,13 +1012,25 @@ static const char* ggml_backend_cann_buffer_type_name(
static ggml_backend_buffer_t static ggml_backend_buffer_t
ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
size_t size) { size_t size) {
ggml_backend_cann_context* cann_ctx = ggml_backend_cann_buffer_type_context* buft_ctx =
(ggml_backend_cann_context*)buft->device->context; (ggml_backend_cann_buffer_type_context*)buft->context;
ggml_cann_pool_alloc* alloc = new ggml_cann_pool_alloc(cann_ctx->pool(), size); ggml_cann_set_device(buft_ctx->device);
size = std::max(size, (size_t)1);
void* dev_ptr;
aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
if (err != ACL_SUCCESS) {
GGML_LOG_ERROR(
"%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
__func__, size / 1024.0 / 1024.0, buft_ctx->device,
aclGetRecentErrMsg());
return nullptr;
}
ggml_backend_cann_buffer_context* ctx = ggml_backend_cann_buffer_context* ctx =
new ggml_backend_cann_buffer_context(cann_ctx->device, alloc); new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface, return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
ctx, size); ctx, size);