Xuan Son Nguyen 2025-01-16 12:44:21 +01:00
parent 1782462bd4
commit 49822bab15
3 changed files with 40 additions and 39 deletions


@@ -64,6 +64,33 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
}
}
// return a list of splits for a given path
// for example, given "<name>-00002-of-00004.gguf", returns list of all 4 splits
static std::vector<std::string> llama_get_list_splits(const std::string & path, const int idx, const int n_split) {
std::vector<std::string> paths;
std::string split_prefix;
std::vector<char> buf(llama_path_max(), 0);
{
int ret = llama_split_prefix(buf.data(), buf.size(), path.c_str(), idx, n_split);
if (!ret) {
throw std::runtime_error(format("invalid split file name: %s", path.c_str()));
}
split_prefix = std::string(buf.data(), ret);
}
if (split_prefix.empty()) {
throw std::runtime_error(format("invalid split file: %s", path.c_str()));
}
for (int idx = 0; idx < n_split; ++idx) {
int ret = llama_split_path(buf.data(), buf.size(), split_prefix.c_str(), idx, n_split);
paths.push_back(std::string(buf.data(), ret));
}
return paths;
}
namespace GGUFMeta {
template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int64_t)>
struct GKV_Base_Type {
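For reference, a minimal standalone sketch of the "<prefix>-%05d-of-%05d.gguf" naming scheme that the new helper reconstructs. This is an illustration only: the committed code goes through llama_split_prefix()/llama_split_path() as shown above, and list_splits_sketch is a hypothetical name.

#include <cstdio>
#include <string>
#include <vector>

// build the full list of split paths from a prefix, mimicking the naming
// convention in the doc comment above ("<name>-00002-of-00004.gguf");
// splits are numbered from 1 in the file name
static std::vector<std::string> list_splits_sketch(const std::string & prefix, int n_split) {
    std::vector<std::string> paths;
    char buf[512];
    for (int i = 0; i < n_split; ++i) {
        snprintf(buf, sizeof(buf), "%s-%05d-of-%05d.gguf", prefix.c_str(), i + 1, n_split);
        paths.push_back(buf);
    }
    return paths;
}

int main() {
    // prints model-00001-of-00004.gguf ... model-00004-of-00004.gguf
    for (const auto & p : list_splits_sketch("model", 4)) {
        printf("%s\n", p.c_str());
    }
}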
@@ -466,16 +493,7 @@ llama_model_loader::llama_model_loader(
// Load additional GGML contexts
if (n_split > 1) {
// generate list of splits if needed
if (splits.empty()) {
splits = llama_get_list_splits(fname, n_split);
}
// in case user give a custom list of splits, check if it matches the expected number
if (n_split != (uint16_t)splits.size()) {
throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split));
}
// make sure the main file is loaded first
uint16_t idx = 0;
const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO);
get_key(kv_split_no, idx);
@@ -483,10 +501,21 @@ llama_model_loader::llama_model_loader(
throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str()));
}
// generate list of splits if needed
if (splits.empty()) {
splits = llama_get_list_splits(fname, idx, n_split);
}
// in case user give a custom list of splits, check if it matches the expected number
if (n_split != (uint16_t)splits.size()) {
throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split));
}
if (trace > 0) {
LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split);
}
// load other splits
for (idx = 1; idx < n_split; idx++) {
const char * fname_split = splits[idx].c_str();
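The "custom list of splits" check above matters when the caller supplies an explicit file list instead of letting the loader derive it from the main file's name. A usage sketch, assuming an entry point of the form llama_model_load_from_splits(paths, n_paths, params); that name and signature are assumptions and do not appear in this diff.

#include "llama.h"

int main() {
    // explicit split list: the first entry must be the first split
    // (split.no == 0) and the count must match the split.count metadata,
    // otherwise the loader throws as shown above
    const char * paths[] = {
        "model-00001-of-00003.gguf",
        "model-00002-of-00003.gguf",
        "model-00003-of-00003.gguf",
    };
    llama_model_params params = llama_model_default_params();
    llama_model * model = llama_model_load_from_splits(paths, 3, params);
    if (model == NULL) {
        return 1;
    }
    llama_model_free(model);
    return 0;
}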
@@ -1093,28 +1122,3 @@ void llama_model_loader::print_info() const {
LLAMA_LOG_INFO("%s: file size = %.2f GiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0/1024.0, n_bytes*8.0/n_elements);
}
}
std::vector<std::string> llama_get_list_splits(const std::string & path, const int n_split) {
std::vector<std::string> paths;
std::string split_prefix;
std::vector<char> buf(llama_path_max(), 0);
// brute force to find the split prefix
for (int idx = 0; idx < n_split; ++idx) {
int ret = llama_split_prefix(buf.data(), buf.size(), path.c_str(), idx, n_split);
if (ret) {
split_prefix = std::string(buf.data(), ret);
}
}
if (split_prefix.empty()) {
throw std::runtime_error(format("invalid split file: %s", path.c_str()));
}
for (int idx = 0; idx < n_split; ++idx) {
int ret = llama_split_path(buf.data(), buf.size(), split_prefix.c_str(), idx, n_split);
paths.push_back(std::string(buf.data(), ret));
}
return paths;
}
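The removed version above had to brute-force the prefix because llama_split_prefix() only succeeds when the supplied index matches the one encoded in the file name; the new static helper can instead pass the index it just read from the split.no metadata. A self-contained sketch of that behaviour (split_prefix_sketch is hypothetical and only mimics the suffix check, assuming the "-%05d-of-%05d.gguf" naming):

#include <cstdio>
#include <cstring>
#include <string>

// recover the prefix from a split file name, but only if the expected
// suffix for this (idx, n_split) pair actually matches the path
static std::string split_prefix_sketch(const std::string & path, int idx, int n_split) {
    char suffix[64];
    snprintf(suffix, sizeof(suffix), "-%05d-of-%05d.gguf", idx + 1, n_split);
    const size_t n = strlen(suffix);
    if (path.size() > n && path.compare(path.size() - n, n, suffix) == 0) {
        return path.substr(0, path.size() - n); // match: prefix recovered
    }
    return ""; // wrong idx: nothing recovered, hence the old brute-force loop
}

int main() {
    // only the matching index yields a prefix
    printf("idx 0: '%s'\n", split_prefix_sketch("model-00002-of-00004.gguf", 0, 4).c_str()); // ''
    printf("idx 1: '%s'\n", split_prefix_sketch("model-00002-of-00004.gguf", 1, 4).c_str()); // 'model'
}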


@@ -165,7 +165,3 @@ struct llama_model_loader {
void print_info() const;
};
// return a list of splits for a given path
// for example, given "<name>-00002-of-00004.gguf", returns list of all 4 splits
std::vector<std::string> llama_get_list_splits(const std::string & path, const int n_split);


@@ -9496,6 +9496,7 @@ static struct llama_model * llama_model_load_from_file_impl(
return model;
}
// deprecated
struct llama_model * llama_load_model_from_file(
const char * path_model,
struct llama_model_params params) {
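A minimal migration sketch for the newly deprecated entry point, assuming its replacement is llama_model_load_from_file() and that llama_model_default_params() and llama_model_free() are available; verify the exact names against the llama.h in your checkout.

#include "llama.h"

int main() {
    llama_model_params params = llama_model_default_params();

    // deprecated spelling, kept as a thin wrapper over the *_impl function above:
    //   llama_model * model = llama_load_model_from_file("model-00001-of-00004.gguf", params);

    // preferred spelling (assumed replacement):
    llama_model * model = llama_model_load_from_file("model-00001-of-00004.gguf", params);
    if (model == NULL) {
        return 1;
    }
    llama_model_free(model);
    return 0;
}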