refactoring: Moved the unified memory code in the correct location.

This commit is contained in:
matteo serva 2024-07-16 18:45:23 +02:00
parent 82fadbd792
commit f258b9273b

View file

@ -130,7 +130,16 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
}
return res;
#else
return cudaMalloc(ptr, size);
cudaError_t err;
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
{
err = cudaMallocManaged(ptr, size);
}
else
{
err = cudaMalloc(ptr, size);
}
return err;
#endif
}
@ -558,14 +567,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
void * dev_ptr;
cudaError_t err;
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
{
err = cudaMallocManaged(&dev_ptr, size);
}
else {
err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
}
cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
if (err != cudaSuccess) {
// clear the error
cudaGetLastError();