diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 9e1acd3f1..c959a63bd 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -58,8 +58,8 @@ #define cudaGetDeviceProperties hipGetDeviceProperties #define cudaGetErrorString hipGetErrorString #define cudaGetLastError hipGetLastError -#define cudaMalloc hipMalloc -#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) +#define cudaMalloc(ptr, size) hipMallocManaged(ptr, size, hipMemAttachGlobal) +#define cudaMallocHost(ptr, size) hipMallocHost(ptr, size) #define cudaMemcpy hipMemcpy #define cudaMemcpy2DAsync hipMemcpy2DAsync #define cudaMemcpyAsync hipMemcpyAsync