some modifications after review
This commit is contained in:
parent
58652e42c3
commit
1c79893ca2
1 changed files with 23 additions and 15 deletions
|
@ -471,10 +471,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
|
|||
*/
|
||||
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
|
||||
int device) {
|
||||
if (device == 0) {
|
||||
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
|
||||
}
|
||||
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
|
||||
}
|
||||
|
||||
// cann buffer
|
||||
|
@ -486,22 +483,21 @@ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
|
|||
*/
|
||||
struct ggml_backend_cann_buffer_context {
|
||||
int32_t device; ///< The device ID associated with this buffer context.
|
||||
ggml_cann_pool_alloc* alloc; ///< Pointer to the device memory allocated for the buffer.
|
||||
void* dev_ptr = nullptr;
|
||||
|
||||
/**
|
||||
* @brief Constructor to initialize the CANN buffer context.
|
||||
*
|
||||
* @param device The device ID associated with this buffer context.
|
||||
* @param alloc Pointer to the device memory allocated for the buffer.
|
||||
*/
|
||||
ggml_backend_cann_buffer_context(int32_t device, ggml_cann_pool_alloc* alloc)
|
||||
ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
|
||||
: device(device),
|
||||
alloc(alloc) {}
|
||||
dev_ptr(dev_ptr) {}
|
||||
|
||||
/**
|
||||
* @brief Destructor to free the device memory allocated for the buffer.
|
||||
*/
|
||||
~ggml_backend_cann_buffer_context() { delete alloc; }
|
||||
~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr));}
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -547,7 +543,7 @@ static void* ggml_backend_cann_buffer_get_base(
|
|||
ggml_backend_buffer_t buffer) {
|
||||
ggml_backend_cann_buffer_context* ctx =
|
||||
(ggml_backend_cann_buffer_context*)buffer->context;
|
||||
return ctx->alloc->get();
|
||||
return ctx->dev_ptr;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -954,7 +950,7 @@ static void ggml_backend_cann_buffer_clear(
|
|||
(ggml_backend_cann_buffer_context*)buffer->context;
|
||||
|
||||
ggml_cann_set_device(ctx->device);
|
||||
ACL_CHECK(aclrtMemset(ctx->alloc->get(), buffer->size, value, buffer->size));
|
||||
ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1016,13 +1012,25 @@ static const char* ggml_backend_cann_buffer_type_name(
|
|||
static ggml_backend_buffer_t
|
||||
ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
|
||||
size_t size) {
|
||||
ggml_backend_cann_context* cann_ctx =
|
||||
(ggml_backend_cann_context*)buft->device->context;
|
||||
ggml_backend_cann_buffer_type_context* buft_ctx =
|
||||
(ggml_backend_cann_buffer_type_context*)buft->context;
|
||||
|
||||
ggml_cann_pool_alloc* alloc = new ggml_cann_pool_alloc(cann_ctx->pool(), size);
|
||||
ggml_cann_set_device(buft_ctx->device);
|
||||
|
||||
size = std::max(size, (size_t)1);
|
||||
|
||||
void* dev_ptr;
|
||||
aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
|
||||
if (err != ACL_SUCCESS) {
|
||||
GGML_LOG_ERROR(
|
||||
"%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
|
||||
__func__, size / 1024.0 / 1024.0, buft_ctx->device,
|
||||
aclGetRecentErrMsg());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ggml_backend_cann_buffer_context* ctx =
|
||||
new ggml_backend_cann_buffer_context(cann_ctx->device, alloc);
|
||||
new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
|
||||
|
||||
return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
|
||||
ctx, size);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue