Check CUDA memory pool support in device properties.
This commit is contained in:
parent
abb77e7319
commit
ce4df17e42
1 changed file with 13 additions and 8 deletions
21
ggml-cuda.cu
21
ggml-cuda.cu
|
@ -5844,7 +5844,19 @@ void ggml_init_cublas() {
|
||||||
for (int id = 0; id < g_device_count; ++id) {
|
for (int id = 0; id < g_device_count; ++id) {
|
||||||
cudaDeviceProp prop;
|
cudaDeviceProp prop;
|
||||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
|
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
|
||||||
fprintf(stderr, " Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
|
fprintf(stderr, " Device %d: %s, compute capability %d.%d", id, prop.name, prop.major, prop.minor);
|
||||||
|
|
||||||
|
// configure memory pool
|
||||||
|
if (prop.memoryPoolsSupported == 1) {
|
||||||
|
cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
|
||||||
|
if (err == cudaSuccess) {
|
||||||
|
size_t treshold = UINT64_MAX;
|
||||||
|
CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
|
||||||
|
fprintf(stderr, ", CUDA memory pool is supported\n");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fprintf(stderr, ", CUDA memory pool is not supported\n");
|
||||||
|
}
|
||||||
|
|
||||||
g_tensor_split[id] = total_vram;
|
g_tensor_split[id] = total_vram;
|
||||||
total_vram += prop.totalGlobalMem;
|
total_vram += prop.totalGlobalMem;
|
||||||
|
@ -5869,13 +5881,6 @@ void ggml_init_cublas() {
|
||||||
// create cublas handle
|
// create cublas handle
|
||||||
CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
|
CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
|
||||||
CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
|
CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
|
||||||
|
|
||||||
// configure memory pool
|
|
||||||
cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
|
|
||||||
if (err == cudaSuccess) {
|
|
||||||
size_t treshold = UINT64_MAX;
|
|
||||||
CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// configure logging to stdout
|
// configure logging to stdout
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue