generalize LLAMA_SPLIT_LAYER for all backends, do not expose device count and memory in llama.h

slaren 2024-02-05 02:16:19 +01:00
parent c71316f825
commit daa6a9c303
2 changed files with 33 additions and 42 deletions

llama.cpp

@@ -1367,6 +1367,33 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(int fallback_g
GGML_UNUSED(tensor_split);
}
static size_t llama_get_device_count() {
#if defined(GGML_USE_CUBLAS)
return ggml_backend_cuda_get_device_count();
#elif defined(GGML_USE_VULKAN)
return ggml_backend_vk_get_device_count();
#else
return 1;
#endif
}
static size_t llama_get_device_memory(int device) {
#if defined(GGML_USE_CUBLAS)
size_t total;
size_t free;
ggml_backend_cuda_get_device_memory(device, &total, &free);
return free;
#elif defined(GGML_USE_VULKAN)
size_t total;
size_t free;
ggml_backend_vk_get_device_memory(device, &total, &free);
return free;
#else
return 1;
GGML_UNUSED(device);
#endif
}
//
// globals
//
@@ -3402,20 +3429,18 @@ static bool llm_load_tensors(
model.buft_layer[i] = llama_default_buffer_type_cpu(true);
}
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_VULKAN)
if (split_mode == LLAMA_SPLIT_LAYER) {
// calculate the split points
int device_count = llama_get_device_count();
bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
std::vector<float> splits_vec(device_count);
float * splits = splits_vec.data();
std::vector<float> splits(device_count);
if (all_zero) {
// default split, by free memory
for (int i = 0; i < device_count; ++i) {
splits[i] = llama_get_default_device_split(i);
splits[i] = llama_get_device_memory(i);
}
} else {
std::copy(tensor_split, tensor_split + device_count, splits);
std::copy(tensor_split, tensor_split + device_count, splits.begin());
}
// sum and normalize the splits to get the split points
@@ -3431,19 +3456,17 @@ static bool llm_load_tensors(
// assign the repeating layers to the devices according to the splits
int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
for (int64_t i = i_gpu_start; i < n_layer; ++i) {
int layer_gpu = std::upper_bound(splits, splits + device_count, float(i - i_gpu_start)/act_gpu_layers) - splits;
int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - splits.begin();
model.buft_layer[i] = llama_default_buffer_type_offload(layer_gpu);
}
// assign the output layer
if (n_gpu_layers > n_layer) {
int layer_gpu = std::upper_bound(splits, splits + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - splits;
int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - splits.begin();
model.buft_output = llama_default_buffer_type_offload(layer_gpu);
} else {
model.buft_output = llama_default_buffer_type_cpu(true);
}
} else
#endif
{
} else {
ggml_backend_buffer_type_t split_buft;
if (split_mode == LLAMA_SPLIT_ROW) {
split_buft = llama_default_buffer_type_split(main_gpu, tensor_split);
@@ -10300,36 +10323,6 @@ size_t llama_max_devices(void) {
#endif
}
size_t llama_get_device_count(void) {
#if defined(GGML_USE_METAL)
return 1;
#elif defined(GGML_USE_CUBLAS)
return ggml_backend_cuda_get_device_count();
#elif defined(GGML_USE_SYCL)
return 1;
#elif defined(GGML_USE_VULKAN)
return ggml_backend_vk_get_device_count();
#else
return 0;
#endif
}
LLAMA_API size_t llama_get_default_device_split(int device) {
#if defined(GGML_USE_CUBLAS)
size_t total;
size_t free;
ggml_backend_cuda_get_device_memory(device, &total, &free);
return free;
#elif defined(GGML_USE_VULKAN)
size_t total;
size_t free;
ggml_backend_vk_get_device_memory(device, &total, &free);
return free;
#else
return 1;
#endif
}
bool llama_supports_mmap(void) {
return llama_mmap::SUPPORTED;
}

llama.h

@ -325,8 +325,6 @@ extern "C" {
LLAMA_API int64_t llama_time_us(void);
LLAMA_API size_t llama_max_devices(void);
LLAMA_API size_t llama_get_device_count(void);
LLAMA_API size_t llama_get_default_device_split(int device);
LLAMA_API bool llama_supports_mmap (void);
LLAMA_API bool llama_supports_mlock (void);