From e474e456ebaa5a169d7ea6d12ddb9a7c4087d971 Mon Sep 17 00:00:00 2001 From: Pierrick HYMBERT Date: Fri, 22 Mar 2024 07:48:50 +0100 Subject: [PATCH] llama_split_prefix: use a clearer version, not pass split path len but dest max len. Co-authored-by: Xuan Son Nguyen --- examples/gguf-split/gguf-split.cpp | 2 +- llama.cpp | 30 ++++++++++++------------------ llama.h | 4 ++-- 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/examples/gguf-split/gguf-split.cpp b/examples/gguf-split/gguf-split.cpp index 3f582506d..f703588e1 100644 --- a/examples/gguf-split/gguf-split.cpp +++ b/examples/gguf-split/gguf-split.cpp @@ -355,7 +355,7 @@ static void gguf_merge(const split_params & split_params) { } // Verify the file naming and extract split_prefix - if (!llama_split_prefix(split_prefix, split_path, strlen(split_path), i_split, n_split)) { + if (!llama_split_prefix(split_prefix, sizeof (split_prefix), split_path, i_split, n_split)) { fprintf(stderr, "\n%s: unexpected input file name: %s" " i_split=%d" " n_split=%d\n", __func__, diff --git a/llama.cpp b/llama.cpp index 092eae8f6..ee0318feb 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2888,7 +2888,7 @@ struct llama_model_loader { } char split_prefix[PATH_MAX] = {0}; - if (!llama_split_prefix(split_prefix, fname.c_str(), fname.size(), idx, n_split)) { + if (!llama_split_prefix(split_prefix, sizeof(split_prefix), fname.c_str(), idx, n_split)) { throw std::runtime_error(format("invalid split file: %s", fname.c_str())); } @@ -14806,25 +14806,19 @@ LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * pa return 0; } -LLAMA_API int llama_split_prefix(char * dest, const char * split_path, size_t split_path_len, int split_no, int split_count) { - char split_prefix[PATH_MAX] = {0}; - int split_no_file = 0; - int split_count_file = 0; - const char * split_format = "-00000-of-00000.gguf"; +int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) { + std::string str_split_path(split_path); + char postfix[32]; + sprintf(postfix, "-%05d-of-%05d.gguf", split_no + 1, split_count); + std::string str_postfix(postfix); - if (split_path_len > strlen(split_format) + 1) { - size_t prefix_len = split_path_len - strlen(split_format); - if (prefix_len >= sizeof(split_prefix)) { - prefix_len = sizeof(split_prefix) - 1; // leave room for null terminator - } - strncpy(split_prefix, split_path, prefix_len); - - int n = sscanf(&split_path[0] + strlen(split_prefix), "-%d-of-%d", &split_no_file, &split_count_file); - if (n == 2 && split_no_file - 1 == split_no && split_count_file == split_count) { - strcpy(dest, split_prefix); - return strlen(split_prefix); - } + // check if dest ends with postfix + auto size_prefix = str_split_path.size() - str_postfix.size(); + if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) { + strncpy(dest, split_path, std::min(size_prefix, maxlen)); + return size_prefix; } + return 0; } diff --git a/llama.h b/llama.h index c23172c55..7e8ac4b62 100644 --- a/llama.h +++ b/llama.h @@ -966,9 +966,9 @@ extern "C" { LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count); /// @details Extract the path prefix from the split_path if and only if the split_no and split_count match. - /// llama_split_prefix(split_prefix, "/models/ggml-model-q4_0-00002-of-00004.gguf", 43, 2, 4) => split_prefix = "/models/ggml-model-q4_0" + /// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0" // Returns the split_prefix length. - LLAMA_API int llama_split_prefix(char * split_prefix, const char * split_path, size_t split_path_len, int split_no, int split_count); + LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count); // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);