cuda : do not create buffer types for devices that don't exist (fixes usage without CUDA devices available)

This commit is contained in:
slaren 2024-01-07 00:33:51 +01:00
parent 2f2c36799d
commit 72b74f364b
2 changed files with 4 additions and 3 deletions

View file

@ -9568,6 +9568,10 @@ static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
// FIXME: this is not thread safe
if (device >= ggml_backend_cuda_get_device_count()) {
return nullptr;
}
static struct ggml_backend_buffer_type ggml_backend_cuda_buffer_types[GGML_CUDA_MAX_DEVICES];
static bool ggml_backend_cuda_buffer_type_initialized = false;
@ -9793,7 +9797,6 @@ ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * ten
// FIXME: this is not thread safe
static std::map<std::array<float, GGML_CUDA_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split_arr = {};
bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_CUDA_MAX_DEVICES, [](float x) { return x == 0.0f; });

View file

@ -47,8 +47,6 @@ GGML_API int ggml_backend_cuda_get_device_count(void);
GGML_API void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
GGML_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
#ifdef __cplusplus
}
#endif