All memory pool operation are checked during init phase. For CUDA 12+ device properties checked.

This commit is contained in:
Oleksii Maryshchenko 2023-11-04 10:25:51 +01:00
parent 815bf1a2f6
commit 56e516240a

View file

@ -5849,16 +5849,43 @@ void ggml_init_cublas() {
cudaDeviceProp prop; cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
fprintf(stderr, " Device %d: %s, compute capability %d.%d", id, prop.name, prop.major, prop.minor); fprintf(stderr, " Device %d: %s, compute capability %d.%d", id, prop.name, prop.major, prop.minor);
#if defined(CUDA_USE_MEMORY_POOL) #if defined(CUDA_USE_MEMORY_POOL)
// configure memory pool bool support_mem_pool = true;
#if CUDART_VERSION >= 12000
support_mem_pool = (prop.memoryPoolsSupported == 1);
#endif
if (support_mem_pool) {
cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id); cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
if (err == cudaSuccess) { if (err == cudaSuccess) {
size_t treshold = UINT64_MAX; size_t treshold = UINT64_MAX;
CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold)); err = (cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
if (err == cudaSuccess) {
fprintf(stderr, ", CUDA memory pool is supported\n"); fprintf(stderr, ", CUDA memory pool is supported\n");
} else { } else {
g_cudaMemPools[id] = nullptr; g_cudaMemPools[id] = nullptr;
fprintf(stderr, ", CUDA memory pool is not supported\n"); fprintf(stderr, ", CUDA memory pool is not supported (release threshold error)\n");
}
} else {
g_cudaMemPools[id] = nullptr;
fprintf(stderr, ", CUDA memory pool is not supported (cant load default pool)\n");
}
// test alloc/dealoc
if (err == cudaSuccess) {
void *testPtr;
size_t testSize = 1024;
err = cudaMallocFromPoolAsync(&testPtr, testSize, g_cudaMemPools[id], g_cudaStreams[id][0]);
if (err == cudaSuccess) {
err = cudaFreeAsync(testPtr, g_cudaStreams[id][0]);
if (err != cudaSuccess) {
g_cudaMemPools[id] = nullptr;
fprintf(stderr, ", CUDA memory pool is not supported (deallocation failed)\n");
}
} else {
g_cudaMemPools[id] = nullptr;
fprintf(stderr, ", CUDA memory pool is not supported (allocation failed)\n");
}
}
} }
#endif #endif
g_tensor_split[id] = total_vram; g_tensor_split[id] = total_vram;