// Convert a human-readable size such as "128M" or "4G" to a number of bytes.
//
// The expected form is N(M|G): an unsigned decimal integer N immediately
// followed by a single unit suffix, where M means mebibytes (N * 1024^2)
// and G means gibibytes (N * 1024^3).
//
// Throws std::invalid_argument when the string is empty, the unit suffix is
// not M or G, the numeric part is missing or not a plain decimal number,
// the value is zero, or the result does not fit in size_t.
static size_t split_str_to_n_bytes(std::string str) {
    if (str.empty()) {
        // calling back() on an empty string is undefined behavior
        throw std::invalid_argument("error: size must not be empty, expected N(M|G)");
    }

    const char suffix = str.back();
    size_t unit = 0;
    if (suffix == 'M') {
        unit = static_cast<size_t>(1024) * 1024; // megabytes
    } else if (suffix == 'G') {
        unit = static_cast<size_t>(1024) * 1024 * 1024; // gigabytes
    } else {
        // build a real std::string from the char: adding a char to a string
        // literal is pointer arithmetic, not concatenation
        throw std::invalid_argument("error: supported units are M (megabytes) or G (gigabytes), but got " + std::string(1, suffix));
    }

    const size_t n_digits = str.size() - 1;
    if (n_digits == 0) {
        throw std::invalid_argument("error: missing size value before unit: " + str);
    }

    // largest N for which N * unit still fits in size_t
    const size_t max_count = static_cast<size_t>(-1) / unit;

    size_t count = 0;
    for (size_t i = 0; i < n_digits; ++i) {
        const char c = str[i];
        if (c < '0' || c > '9') {
            throw std::invalid_argument("error: invalid size value: " + str);
        }
        const size_t digit = static_cast<size_t>(c - '0');
        // refuse before the accumulation could exceed max_count
        if (digit > max_count || count > (max_count - digit) / 10) {
            throw std::invalid_argument("error: size is too large: " + str);
        }
        count = count * 10 + digit;
    }

    if (count == 0) {
        // a zero size would silently disable splitting by size
        throw std::invalid_argument("error: size must be a positive value");
    }

    // the whole computation is done in size_t, so values of 2G and above
    // no longer overflow int arithmetic
    return count * unit;
}
arg_prefix = "--"; @@ -62,6 +80,8 @@ static bool split_params_parse_ex(int argc, const char ** argv, split_params & p } bool arg_found = false; + bool is_op_set = false; + bool is_mode_set = false; if (arg == "-h" || arg == "--help") { split_print_usage(argv[0]); exit(0); @@ -72,22 +92,41 @@ static bool split_params_parse_ex(int argc, const char ** argv, split_params & p exit(0); } + if (is_op_set) { + throw std::invalid_argument("error: either --split or --merge can be specified, but not both"); + } if (arg == "--merge") { arg_found = true; + is_op_set = true; params.operation = SPLIT_OP_MERGE; } if (arg == "--split") { arg_found = true; + is_op_set = true; params.operation = SPLIT_OP_SPLIT; } + + if (is_mode_set) { + throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both"); + } if (arg == "--split-max-tensors") { if (++arg_idx >= argc) { invalid_param = true; break; } arg_found = true; + is_mode_set = true; params.n_split_tensors = atoi(argv[arg_idx]); } + if (arg == "--split-max-size") { + if (++arg_idx >= argc) { + invalid_param = true; + break; + } + arg_found = true; + is_mode_set = true; + params.n_bytes_split = split_str_to_n_bytes(argv[arg_idx]); + } if (!arg_found) { throw std::invalid_argument("error: unknown argument: " + arg); @@ -162,8 +201,22 @@ struct split_strategy { n_split(std::ceil(1. 
* n_tensors / params.n_split_tensors)) { } - bool should_split() const { - return i_tensor < n_tensors && i_tensor % params.n_split_tensors == 0; + bool should_split() { + if (params.n_bytes_split > 0) { + // split by max size per file + size_t curr_size = fout.tellp(); + if (i_tensor >= n_tensors - 1) { + return false; + } + // get size of next tensor + const char * t_name = gguf_get_tensor_name(ctx_gguf, i_tensor + 1); + struct ggml_tensor * t = ggml_get_tensor(ctx_meta, t_name); + size_t next_size = curr_size + ggml_nbytes(t); + return next_size > params.n_bytes_split; + } else { + // split by number of tensors per file + return i_tensor < n_tensors && i_tensor % params.n_split_tensors == 0; + } } void split_start() {