Multi-GPU memory pool access + check memory pool support of multiple GPUs and the main GPU.

Author: Oleksii Maryshchenko
Date:   2023-11-04 17:29:08 +01:00
parent 56e516240a
commit 81931b2ea7


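For background: the support check in this commit keys off whether a default memory pool could be obtained for each device (`g_cudaMemPools[id]`). Below is a minimal standalone sketch, not part of the diff, of how stream-ordered pool support can also be probed per device through a device attribute; the `main()` wrapper and the printed messages are illustrative assumptions.

```cpp
// Sketch: probe stream-ordered memory pool support on every visible GPU.
// cudaDevAttrMemoryPoolsSupported reports 1 on devices that support cudaMallocAsync pools.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device_count = 0;
    if (cudaGetDeviceCount(&device_count) != cudaSuccess) {
        return 1;
    }
    for (int id = 0; id < device_count; ++id) {
        int pools_supported = 0;
        cudaDeviceGetAttribute(&pools_supported, cudaDevAttrMemoryPoolsSupported, id);
        printf("device %d: memory pools %s\n", id, pools_supported ? "supported" : "not supported");
    }
    return 0;
}
```

Devices where the attribute reports 0 would be the ones that end up on the fallback path this commit adds.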
@@ -503,6 +503,7 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
static void * g_scratch_buffer = nullptr;
static size_t g_scratch_size = 0; // disabled by default
static size_t g_scratch_offset = 0;
+static bool g_cudaMutliGpuMemPoolSupported = true;
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
@@ -5813,7 +5814,7 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
-if (g_cudaMemPools[id] == nullptr) {
+if (g_cudaMemPools[id] == nullptr || !g_cudaMutliGpuMemPoolSupported) {
return ggml_cuda_pool_free(ptr, actual_size);
}
CUDA_CHECK(cudaFreeAsync(ptr, stream));
@@ -5896,6 +5897,49 @@ void ggml_init_cublas() {
g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
+#if defined(CUDA_USE_MEMORY_POOL)
+    if (g_device_count > 1) {
+        // give access to devices memory pools
+        if (g_cudaMemPools[g_main_device] != nullptr) {
+            cudaMemPool_t main_device_pool;
+            cudaMemAccessDesc desc_main_device = {};
+            desc_main_device.location.type = cudaMemLocationTypeDevice;
+            desc_main_device.location.id = g_main_device;
+            desc_main_device.flags = cudaMemAccessFlagsProtReadWrite;
+            CUDA_CHECK(cudaDeviceGetDefaultMemPool(&main_device_pool, g_main_device));
+            for (int id = 0; id < g_device_count; ++id) {
+                if (id == g_main_device) continue;
+                if (g_cudaMemPools[id] == nullptr) {
+                    fprintf(stderr,
+                        "Warning: Device %d doesn't support CUDA memory pool, skipping pool access config\n",
+                        id);
+                }
+                cudaMemAccessDesc desc_device = {};
+                desc_device.location.type = cudaMemLocationTypeDevice;
+                desc_device.location.id = id;
+                desc_device.flags = cudaMemAccessFlagsProtReadWrite;
+                cudaError_t err = cudaMemPoolSetAccess(main_device_pool, &desc_device, 1 /* numDescs */);
+                if (err != cudaSuccess) {
+                    fprintf(stderr, "Can't give device %d access to the main device memory pool\n", id);
+                }
+                cudaMemPool_t mempool;
+                CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, id));
+                err = cudaMemPoolSetAccess(mempool, &desc_main_device, 1 /* numDescs */);
+                if (err != cudaSuccess) {
+                    fprintf(stderr, "Can't give the main device access to the memory pool of device %d\n", id);
+                }
+            }
+        } else {
+            fprintf(stderr,
+                "WARNING: Your main GPU device doesn't support CUDA memory pools. Using custom memory pool implementation.\n");
+            g_cudaMutliGpuMemPoolSupported = false;
+        }
+    }
+#endif
for (int id = 0; id < g_device_count; ++id) {
g_tensor_split[id] /= total_vram;
}
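The block added in the hunk above grants every secondary GPU read/write access to the main device's default memory pool, and grants the main device access to each secondary pool in return. Here is a self-contained sketch of that `cudaMemPoolSetAccess` pattern; the `grant_pool_access` helper and the fixed device ids 0 and 1 are assumptions for illustration, not code from this commit.

```cpp
// Sketch: let device `peer` read/write allocations made from `owner`'s default memory pool.
#include <cstdio>
#include <cuda_runtime.h>

static bool grant_pool_access(int owner, int peer) {
    cudaMemPool_t pool;
    if (cudaDeviceGetDefaultMemPool(&pool, owner) != cudaSuccess) {
        return false;                                // owner has no default pool (no pool support)
    }
    cudaMemAccessDesc desc = {};
    desc.location.type = cudaMemLocationTypeDevice;
    desc.location.id   = peer;                       // device that is being granted access
    desc.flags         = cudaMemAccessFlagsProtReadWrite;
    return cudaMemPoolSetAccess(pool, &desc, 1 /* numDescs */) == cudaSuccess;
}

int main() {
    // Assumed two-GPU setup: grant access in both directions, as the diff does per device pair.
    if (!grant_pool_access(/*owner=*/0, /*peer=*/1)) printf("device 1 cannot access pool of device 0\n");
    if (!grant_pool_access(/*owner=*/1, /*peer=*/0)) printf("device 0 cannot access pool of device 1\n");
    return 0;
}
```

Access is granted one direction per call, which is why the diff issues two `cudaMemPoolSetAccess` calls for each main/secondary device pair.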
@@ -6410,7 +6454,7 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
if (src1_convert_f16) {
-src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
+src1_dfloat = (half *) ggml_cuda_pool_malloc_async(ne00*sizeof(half), &ash, g_main_device, stream);
ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
ne00, 1, sizeof(float), 0, 0,
ne00, 1, sizeof(half), 0, 0, stream);
@@ -6811,7 +6855,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
if (src0_on_device) {
src0_ddf = (float *) src0_extra->data_device[g_main_device];
} else {
-src0_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_asf);
+src0_ddf = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_asf, g_main_device, main_stream);
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
}
@@ -6819,14 +6863,14 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
if (src1_on_device) {
src1_ddf = (float *) src1_extra->data_device[g_main_device];
} else {
-src1_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf);
+src1_ddf = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf, g_main_device, main_stream);
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream));
}
}
if (dst_on_device) {
dst_ddf = (float *) dst_extra->data_device[g_main_device];
} else {
-dst_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(dst), &dst_asf);
+dst_ddf = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(dst), &dst_asf, g_main_device, main_stream);
}
// do the computation
@@ -6838,19 +6882,19 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
}
-if (src0_asf > 0) {
-    ggml_cuda_pool_free(src0_ddf, src0_asf);
-}
-if (src1_asf > 0) {
-    ggml_cuda_pool_free(src1_ddf, src1_asf);
-}
-if (dst_asf > 0) {
-    ggml_cuda_pool_free(dst_ddf, dst_asf);
-}
if (dst->backend == GGML_BACKEND_CPU) {
CUDA_CHECK(cudaDeviceSynchronize());
}
+if (src0_asf > 0) {
+    ggml_cuda_pool_free_async(src0_ddf, src0_asf, g_main_device, main_stream);
+}
+if (src1_asf > 0) {
+    ggml_cuda_pool_free_async(src1_ddf, src1_asf, g_main_device, main_stream);
+}
+if (dst_asf > 0) {
+    ggml_cuda_pool_free_async(dst_ddf, dst_asf, g_main_device, main_stream);
+}
}
static void ggml_cuda_set_peer_access(const int n_tokens) {
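The later hunks replace the synchronous scratch helpers (`ggml_cuda_pool_malloc` / `ggml_cuda_pool_free`) with their `*_async` counterparts, so scratch buffers are allocated and released in stream order rather than via blocking calls. Below is a minimal sketch of the `cudaMallocAsync` / `cudaFreeAsync` pattern that stream-ordered pools expose (the diff's free path uses `cudaFreeAsync`; the allocation helper presumably uses the matching pool-aware call). The buffer size and explicit stream handling are illustrative assumptions.

```cpp
// Sketch: stream-ordered allocation and free from the current device's default pool.
#include <cuda_runtime.h>

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    void * buf = nullptr;
    // Allocation, use, and free are all ordered on `stream`; no device-wide
    // synchronization is needed between them.
    if (cudaMallocAsync(&buf, 1 << 20, stream) == cudaSuccess) {   // 1 MiB scratch (illustrative size)
        // ... enqueue kernels that use `buf` on `stream` ...
        cudaFreeAsync(buf, stream);       // free is ordered after the enqueued work
    }

    cudaStreamSynchronize(stream);        // wait for all stream work before teardown
    cudaStreamDestroy(stream);
    return 0;
}
```

The key property is that allocation, use, and free are ordered on the same stream, so the pool can recycle the memory without any extra synchronization between the free and later allocations on that stream.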