some modifications after review

This commit is contained in:
shanshan shen 2024-11-26 07:09:55 +00:00
parent 58652e42c3
commit 1c79893ca2

View file

@ -471,10 +471,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
*/
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
int device) {
if (device == 0) {
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
}
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
}
// cann buffer
@ -486,22 +483,21 @@ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
*/
struct ggml_backend_cann_buffer_context {
int32_t device; ///< The device ID associated with this buffer context.
ggml_cann_pool_alloc* alloc; ///< Pointer to the device memory allocated for the buffer.
void* dev_ptr = nullptr;
/**
* @brief Constructor to initialize the CANN buffer context.
*
* @param device The device ID associated with this buffer context.
* @param alloc Pointer to the device memory allocated for the buffer.
*/
ggml_backend_cann_buffer_context(int32_t device, ggml_cann_pool_alloc* alloc)
ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
: device(device),
alloc(alloc) {}
dev_ptr(dev_ptr) {}
/**
* @brief Destructor to free the device memory allocated for the buffer.
*/
~ggml_backend_cann_buffer_context() { delete alloc; }
~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr));}
};
/**
@ -547,7 +543,7 @@ static void* ggml_backend_cann_buffer_get_base(
ggml_backend_buffer_t buffer) {
ggml_backend_cann_buffer_context* ctx =
(ggml_backend_cann_buffer_context*)buffer->context;
return ctx->alloc->get();
return ctx->dev_ptr;
}
/**
@ -954,7 +950,7 @@ static void ggml_backend_cann_buffer_clear(
(ggml_backend_cann_buffer_context*)buffer->context;
ggml_cann_set_device(ctx->device);
ACL_CHECK(aclrtMemset(ctx->alloc->get(), buffer->size, value, buffer->size));
ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
}
/**
@ -1016,13 +1012,25 @@ static const char* ggml_backend_cann_buffer_type_name(
static ggml_backend_buffer_t
ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
size_t size) {
ggml_backend_cann_context* cann_ctx =
(ggml_backend_cann_context*)buft->device->context;
ggml_backend_cann_buffer_type_context* buft_ctx =
(ggml_backend_cann_buffer_type_context*)buft->context;
ggml_cann_pool_alloc* alloc = new ggml_cann_pool_alloc(cann_ctx->pool(), size);
ggml_cann_set_device(buft_ctx->device);
size = std::max(size, (size_t)1);
void* dev_ptr;
aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
if (err != ACL_SUCCESS) {
GGML_LOG_ERROR(
"%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
__func__, size / 1024.0 / 1024.0, buft_ctx->device,
aclGetRecentErrMsg());
return nullptr;
}
ggml_backend_cann_buffer_context* ctx =
new ggml_backend_cann_buffer_context(cann_ctx->device, alloc);
new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
ctx, size);