simplify code, more consistent style

This commit is contained in:
slaren 2024-05-26 18:47:42 +02:00
parent b2c0f7f303
commit a6a1abd98e
2 changed files with 4 additions and 11 deletions

View file

@ -119,19 +119,15 @@ int ggml_cuda_get_device() {
return id;
}
// ggml_cuda_host_malloc
static inline cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
#if defined(GGML_USE_HIPBLAS)
#if defined(GGML_HIP_UMA)
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
ggml_cuda_set_device(device);
#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
auto res = hipMallocManaged(ptr, size);
if (res == hipSuccess) {
// if error we "need" to know why...
CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
}
return res;
#else
return hipMalloc(ptr, size);
#endif
#else
return cudaMalloc(ptr, size);
#endif

View file

@ -79,11 +79,8 @@
// Alias CUDA runtime API names to their HIP counterparts so the code below can
// be written against the cuda* spellings.
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchHostFunc hipLaunchHostFunc
// cudaMallocHost expands to hipHostMalloc in both branches; only the
// non-GGML_HIP_UMA branch passes an explicit flags argument
// (hipHostMallocDefault).
// NOTE(review): as written here, cudaMalloc is aliased to hipMalloc only when
// GGML_HIP_UMA is NOT defined -- confirm that this is intended and that no
// caller relies on cudaMalloc being defined in the UMA configuration.
#ifdef GGML_HIP_UMA
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
#else
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#endif
#define cudaMemcpy hipMemcpy
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync