some modifications after review
This commit is contained in:
parent
58652e42c3
commit
1c79893ca2
1 changed files with 23 additions and 15 deletions
|
@ -471,10 +471,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
|
||||||
*/
|
*/
|
||||||
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
|
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
|
||||||
int device) {
|
int device) {
|
||||||
if (device == 0) {
|
|
||||||
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
|
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
|
||||||
}
|
|
||||||
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cann buffer
|
// cann buffer
|
||||||
|
@ -486,22 +483,21 @@ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
|
||||||
*/
|
*/
|
||||||
struct ggml_backend_cann_buffer_context {
|
struct ggml_backend_cann_buffer_context {
|
||||||
int32_t device; ///< The device ID associated with this buffer context.
|
int32_t device; ///< The device ID associated with this buffer context.
|
||||||
ggml_cann_pool_alloc* alloc; ///< Pointer to the device memory allocated for the buffer.
|
void* dev_ptr = nullptr;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Constructor to initialize the CANN buffer context.
|
* @brief Constructor to initialize the CANN buffer context.
|
||||||
*
|
*
|
||||||
* @param device The device ID associated with this buffer context.
|
* @param device The device ID associated with this buffer context.
|
||||||
* @param alloc Pointer to the device memory allocated for the buffer.
|
|
||||||
*/
|
*/
|
||||||
ggml_backend_cann_buffer_context(int32_t device, ggml_cann_pool_alloc* alloc)
|
ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
|
||||||
: device(device),
|
: device(device),
|
||||||
alloc(alloc) {}
|
dev_ptr(dev_ptr) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Destructor to free the device memory allocated for the buffer.
|
* @brief Destructor to free the device memory allocated for the buffer.
|
||||||
*/
|
*/
|
||||||
~ggml_backend_cann_buffer_context() { delete alloc; }
|
~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr));}
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -547,7 +543,7 @@ static void* ggml_backend_cann_buffer_get_base(
|
||||||
ggml_backend_buffer_t buffer) {
|
ggml_backend_buffer_t buffer) {
|
||||||
ggml_backend_cann_buffer_context* ctx =
|
ggml_backend_cann_buffer_context* ctx =
|
||||||
(ggml_backend_cann_buffer_context*)buffer->context;
|
(ggml_backend_cann_buffer_context*)buffer->context;
|
||||||
return ctx->alloc->get();
|
return ctx->dev_ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -954,7 +950,7 @@ static void ggml_backend_cann_buffer_clear(
|
||||||
(ggml_backend_cann_buffer_context*)buffer->context;
|
(ggml_backend_cann_buffer_context*)buffer->context;
|
||||||
|
|
||||||
ggml_cann_set_device(ctx->device);
|
ggml_cann_set_device(ctx->device);
|
||||||
ACL_CHECK(aclrtMemset(ctx->alloc->get(), buffer->size, value, buffer->size));
|
ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1016,13 +1012,25 @@ static const char* ggml_backend_cann_buffer_type_name(
|
||||||
static ggml_backend_buffer_t
|
static ggml_backend_buffer_t
|
||||||
ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
|
ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
|
||||||
size_t size) {
|
size_t size) {
|
||||||
ggml_backend_cann_context* cann_ctx =
|
ggml_backend_cann_buffer_type_context* buft_ctx =
|
||||||
(ggml_backend_cann_context*)buft->device->context;
|
(ggml_backend_cann_buffer_type_context*)buft->context;
|
||||||
|
|
||||||
ggml_cann_pool_alloc* alloc = new ggml_cann_pool_alloc(cann_ctx->pool(), size);
|
ggml_cann_set_device(buft_ctx->device);
|
||||||
|
|
||||||
|
size = std::max(size, (size_t)1);
|
||||||
|
|
||||||
|
void* dev_ptr;
|
||||||
|
aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
|
||||||
|
if (err != ACL_SUCCESS) {
|
||||||
|
GGML_LOG_ERROR(
|
||||||
|
"%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
|
||||||
|
__func__, size / 1024.0 / 1024.0, buft_ctx->device,
|
||||||
|
aclGetRecentErrMsg());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_backend_cann_buffer_context* ctx =
|
ggml_backend_cann_buffer_context* ctx =
|
||||||
new ggml_backend_cann_buffer_context(cann_ctx->device, alloc);
|
new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
|
||||||
|
|
||||||
return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
|
return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
|
||||||
ctx, size);
|
ctx, size);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue