split by max size

This commit is contained in:
ngxson 2024-03-27 12:57:34 +01:00
parent e82f9e2b83
commit 12b255487f

View file

@ -28,6 +28,7 @@ enum split_operation : uint8_t {
struct split_params {
split_operation operation = SPLIT_OP_SPLIT;
size_t n_bytes_split = 0;
int n_split_tensors = 128;
std::string input;
std::string output;
@ -43,12 +44,29 @@ static void split_print_usage(const char * executable) {
printf("options:\n");
printf(" -h, --help show this help message and exit\n");
printf(" --version show version and build info\n");
printf(" --split split GGUF to multiple GGUF (default)\n");
printf(" --split-max-tensors max tensors in each split: default(%d)\n", default_params.n_split_tensors);
printf(" --split split GGUF to multiple GGUF (enabled by default)\n");
printf(" --merge merge multiple GGUF to a single GGUF\n");
printf(" --split-max-tensors max tensors in each split (default: %d)\n", default_params.n_split_tensors);
printf(" --split-max-size N(M|G) max size per split\n");
printf("\n");
}
// Convert a human-readable size such as "128M" or "4G" into a number of bytes.
//
// The string must be a positive integer followed by a unit suffix:
//   M - megabytes (1024 * 1024 bytes)
//   G - gigabytes (1024 * 1024 * 1024 bytes)
//
// Throws std::invalid_argument if the string is empty, the unit is not M or G,
// the numeric part is missing or not positive, or the result does not fit in
// a size_t.
static size_t split_str_to_n_bytes(std::string str) {
    if (str.empty()) {
        // str.back() on an empty string is undefined behaviour, so reject it up front
        throw std::invalid_argument("error: supported units are M (megabytes) or G (gigabytes), but got an empty string");
    }

    const char unit = str.back();
    size_t unit_bytes = 0;
    if (unit == 'M') {
        unit_bytes = static_cast<size_t>(1024) * 1024; // megabytes
    } else if (unit == 'G') {
        unit_bytes = static_cast<size_t>(1024) * 1024 * 1024; // gigabytes
    } else {
        // build the message with std::string: "literal" + char would be pointer arithmetic
        throw std::invalid_argument(std::string("error: supported units are M (megabytes) or G (gigabytes), but got ") + unit);
    }

    // parse the leading integer; the return value tells us whether a number was actually read
    long long n = 0;
    if (sscanf(str.c_str(), "%lld", &n) != 1 || n <= 0) {
        throw std::invalid_argument("error: invalid split size: " + str);
    }

    // do the multiplication in size_t (not int) and refuse values that would wrap around
    const size_t max_units = static_cast<size_t>(-1) / unit_bytes;
    if (static_cast<unsigned long long>(n) > max_units) {
        throw std::invalid_argument("error: split size is too large: " + str);
    }
    return static_cast<size_t>(n) * unit_bytes;
}
static bool split_params_parse_ex(int argc, const char ** argv, split_params & params) {
std::string arg;
const std::string arg_prefix = "--";
@ -62,6 +80,8 @@ static bool split_params_parse_ex(int argc, const char ** argv, split_params & p
}
bool arg_found = false;
bool is_op_set = false;
bool is_mode_set = false;
if (arg == "-h" || arg == "--help") {
split_print_usage(argv[0]);
exit(0);
@ -72,22 +92,41 @@ static bool split_params_parse_ex(int argc, const char ** argv, split_params & p
exit(0);
}
if (is_op_set) {
throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
}
if (arg == "--merge") {
arg_found = true;
is_op_set = true;
params.operation = SPLIT_OP_MERGE;
}
if (arg == "--split") {
arg_found = true;
is_op_set = true;
params.operation = SPLIT_OP_SPLIT;
}
if (is_mode_set) {
throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
}
if (arg == "--split-max-tensors") {
if (++arg_idx >= argc) {
invalid_param = true;
break;
}
arg_found = true;
is_mode_set = true;
params.n_split_tensors = atoi(argv[arg_idx]);
}
if (arg == "--split-max-size") {
if (++arg_idx >= argc) {
invalid_param = true;
break;
}
arg_found = true;
is_mode_set = true;
params.n_bytes_split = split_str_to_n_bytes(argv[arg_idx]);
}
if (!arg_found) {
throw std::invalid_argument("error: unknown argument: " + arg);
@ -162,9 +201,23 @@ struct split_strategy {
n_split(std::ceil(1. * n_tensors / params.n_split_tensors)) {
}
bool should_split() const {
bool should_split() {
if (params.n_bytes_split > 0) {
// split by max size per file
size_t curr_size = fout.tellp();
if (i_tensor >= n_tensors - 1) {
return false;
}
// get size of next tensor
const char * t_name = gguf_get_tensor_name(ctx_gguf, i_tensor + 1);
struct ggml_tensor * t = ggml_get_tensor(ctx_meta, t_name);
size_t next_size = curr_size + ggml_nbytes(t);
return next_size > params.n_bytes_split;
} else {
// split by number of tensors per file
return i_tensor < n_tensors && i_tensor % params.n_split_tensors == 0;
}
}
void split_start() {
ctx_out = gguf_init_empty();