refactoring: Moved the unified memory code in the correct location.
This commit is contained in:
parent
82fadbd792
commit
f258b9273b
1 changed files with 11 additions and 9 deletions
|
@ -130,7 +130,16 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
|
|||
}
|
||||
return res;
|
||||
#else
|
||||
return cudaMalloc(ptr, size);
|
||||
cudaError_t err;
|
||||
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
|
||||
{
|
||||
err = cudaMallocManaged(ptr, size);
|
||||
}
|
||||
else
|
||||
{
|
||||
err = cudaMalloc(ptr, size);
|
||||
}
|
||||
return err;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -558,14 +567,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
|
|||
size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
|
||||
|
||||
void * dev_ptr;
|
||||
cudaError_t err;
|
||||
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
|
||||
{
|
||||
err = cudaMallocManaged(&dev_ptr, size);
|
||||
}
|
||||
else {
|
||||
err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
|
||||
}
|
||||
cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
|
||||
if (err != cudaSuccess) {
|
||||
// clear the error
|
||||
cudaGetLastError();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue