diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 531e87c7a..86d3fb1ff 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -471,10 +471,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { */ std::unique_ptr ggml_backend_cann_context::new_pool_for_device( int device) { - if (device == 0) { - return std::unique_ptr(new ggml_cann_pool_vmm(device)); - } - return std::unique_ptr(new ggml_cann_pool_leg(device)); + return std::unique_ptr(new ggml_cann_pool_vmm(device)); } // cann buffer @@ -486,22 +483,21 @@ std::unique_ptr ggml_backend_cann_context::new_pool_for_device( */ struct ggml_backend_cann_buffer_context { int32_t device; ///< The device ID associated with this buffer context. - ggml_cann_pool_alloc* alloc; ///< Pointer to the device memory allocated for the buffer. + void* dev_ptr = nullptr; /** * @brief Constructor to initialize the CANN buffer context. * * @param device The device ID associated with this buffer context. - * @param alloc Pointer to the device memory allocated for the buffer. */ - ggml_backend_cann_buffer_context(int32_t device, ggml_cann_pool_alloc* alloc) + ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr) : device(device), - alloc(alloc) {} + dev_ptr(dev_ptr) {} /** * @brief Destructor to free the device memory allocated for the buffer. */ - ~ggml_backend_cann_buffer_context() { delete alloc; } + ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr));} }; /** @@ -547,7 +543,7 @@ static void* ggml_backend_cann_buffer_get_base( ggml_backend_buffer_t buffer) { ggml_backend_cann_buffer_context* ctx = (ggml_backend_cann_buffer_context*)buffer->context; - return ctx->alloc->get(); + return ctx->dev_ptr; } /** @@ -954,7 +950,7 @@ static void ggml_backend_cann_buffer_clear( (ggml_backend_cann_buffer_context*)buffer->context; ggml_cann_set_device(ctx->device); - ACL_CHECK(aclrtMemset(ctx->alloc->get(), buffer->size, value, buffer->size)); + ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size)); } /** @@ -1016,13 +1012,25 @@ static const char* ggml_backend_cann_buffer_type_name( static ggml_backend_buffer_t ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)buft->device->context; + ggml_backend_cann_buffer_type_context* buft_ctx = + (ggml_backend_cann_buffer_type_context*)buft->context; - ggml_cann_pool_alloc* alloc = new ggml_cann_pool_alloc(cann_ctx->pool(), size); + ggml_cann_set_device(buft_ctx->device); + + size = std::max(size, (size_t)1); + + void* dev_ptr; + aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST); + if (err != ACL_SUCCESS) { + GGML_LOG_ERROR( + "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n", + __func__, size / 1024.0 / 1024.0, buft_ctx->device, + aclGetRecentErrMsg()); + return nullptr; + } ggml_backend_cann_buffer_context* ctx = - new ggml_backend_cann_buffer_context(cann_ctx->device, alloc); + new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr); return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface, ctx, size);