batched-bench : migrate to gpt_params

Author: Georgi Gerganov
Date:   2024-06-04 11:27:13 +03:00
commit  a149eed043 (parent e7e0381411)
GPG key ID: 449E073F9DC10735 (no known key found for this signature in database)

5 changed files with 114 additions and 117 deletions

diff --git a/common/common.cpp b/common/common.cpp

@@ -283,6 +283,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
 }
 
 bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
+    const char split_delim = ',';
+
     llama_sampling_params & sparams = params.sparams;
 
     if (arg == "-s" || arg == "--seed") {
@@ -1153,6 +1155,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.ppl_stride = std::stoi(argv[i]);
         return true;
     }
+    if (arg == "--ppl-output-type") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.ppl_output_type = std::stoi(argv[i]);
+        return true;
+    }
     if (arg == "-ptc" || arg == "--print-token-count") {
         if (++i >= argc) {
             invalid_param = true;
@@ -1165,14 +1175,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.check_tensors = true;
         return true;
     }
-    if (arg == "--ppl-output-type") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
-        params.ppl_output_type = std::stoi(argv[i]);
-        return true;
-    }
     if (arg == "--hellaswag") {
         params.hellaswag = true;
         return true;
@@ -1465,6 +1467,37 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.chat_template = argv[i];
         return true;
     }
+    if (arg == "-pps") {
+        params.is_pp_shared = true;
+        return true;
+    }
+    if (arg == "-npp") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        auto p = string_split<int>(argv[i], split_delim);
+        params.n_pp.insert(params.n_pp.end(), p.begin(), p.end());
+        return true;
+    }
+    if (arg == "-ntg") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        auto p = string_split<int>(argv[i], split_delim);
+        params.n_tg.insert(params.n_tg.end(), p.begin(), p.end());
+        return true;
+    }
+    if (arg == "-npl") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        auto p = string_split<int>(argv[i], split_delim);
+        params.n_pl.insert(params.n_pl.end(), p.begin(), p.end());
+        return true;
+    }
 #ifndef LOG_DISABLE_LOGS
     // Parse args for logging parameters
     if (log_param_single_parse(argv[i])) {
@@ -1658,6 +1691,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "perplexity", "       --multiple-choice-tasks N",
                                                "number of tasks to use when computing the multiple choice score (default: %zu)", params.multiple_choice_tasks });
     options.push_back({ "perplexity", "       --kl-divergence",          "computes KL-divergence to logits provided via --kl-divergence-base" });
+    options.push_back({ "perplexity", "       --ppl-stride N",           "stride for perplexity calculation (default: %d)", params.ppl_stride });
+    options.push_back({ "perplexity", "       --ppl-output-type {0,1}",
+                                               "output type for perplexity calculation (default: %d)", params.ppl_output_type });
 
     options.push_back({ "parallel" });
     options.push_back({ "*",          "-dt,   --defrag-thold N",         "KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold });
@@ -1718,6 +1754,12 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*",          "-hfr,  --hf-repo REPO",           "Hugging Face model repository (default: unused)" });
     options.push_back({ "*",          "-hff,  --hf-file FILE",           "Hugging Face model file (default: unused)" });
 
+    options.push_back({ "bench" });
+    options.push_back({ "bench",      "-pps",                            "is the prompt shared across parallel sequences (default: %s)", params.is_pp_shared ? "true" : "false" });
+    options.push_back({ "bench",      "-npp n0,n1,...",                  "number of prompt tokens" });
+    options.push_back({ "bench",      "-ntg n0,n1,...",                  "number of text generation tokens" });
+    options.push_back({ "bench",      "-npl n0,n1,...",                  "number of parallel prompts" });
+
     options.push_back({ "server" });
     options.push_back({ "server",     "       --host HOST",              "ip address to listen (default: %s)", params.hostname.c_str() });
     options.push_back({ "server",     "       --port PORT",              "port to listen (default: %d)", params.port });

diff --git a/common/common.h b/common/common.h

@@ -127,8 +127,8 @@ struct gpt_params {
     int32_t control_vector_layer_start = -1; // layer range for control vector
     int32_t control_vector_layer_end   = -1; // layer range for control vector
 
-    int     ppl_stride      = 0;     // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
-    int     ppl_output_type = 0;     // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
+    int32_t ppl_stride      = 0;     // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
+    int32_t ppl_output_type = 0;     // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
                                      //       (which is more convenient to use for plotting)
     //
     bool    hellaswag       = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
@@ -203,6 +203,13 @@ struct gpt_params {
     bool log_json = false;
 
     std::string slot_save_path;
+
+    // batched-bench params
+    bool is_pp_shared = false;
+
+    std::vector<int32_t> n_pp;
+    std::vector<int32_t> n_tg;
+    std::vector<int32_t> n_pl;
 };
 
 void gpt_params_handle_model_default(gpt_params & params);
@@ -223,6 +230,20 @@ std::vector<std::string> string_split(std::string input, char separator);
 std::string string_strip(const std::string & str);
 std::string string_get_sortable_timestamp();
 
+template<class T>
+static std::vector<T> string_split(const std::string & str, char delim) {
+    std::vector<T> values;
+    std::istringstream str_stream(str);
+    std::string token;
+    while (std::getline(str_stream, token, delim)) {
+        T value;
+        std::istringstream token_stream(token);
+        token_stream >> value;
+        values.push_back(value);
+    }
+    return values;
+}
+
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 void string_process_escapes(std::string & input);
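For reference, the `string_split` helper added above can be exercised in isolation. A minimal self-contained sketch (the `main` harness is illustrative, not part of the commit):

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Copy of the string_split helper added to common.h above.
template<class T>
static std::vector<T> string_split(const std::string & str, char delim) {
    std::vector<T> values;
    std::istringstream str_stream(str);
    std::string token;
    while (std::getline(str_stream, token, delim)) {
        T value;
        std::istringstream token_stream(token);
        token_stream >> value;
        values.push_back(value);
    }
    return values;
}

int main() {
    // "-npp 128,256,512" style argument -> {128, 256, 512}
    for (const int n : string_split<int>("128,256,512", ',')) {
        std::cout << n << '\n';
    }
    // The same template splits into strings, which is how llama-bench uses it below.
    for (const std::string & s : string_split<std::string>("q4_0,q8_0", ',')) {
        std::cout << s << '\n';
    }
    return 0;
}
```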

diff --git a/examples/batched-bench/README.md b/examples/batched-bench/README.md

@@ -10,16 +10,16 @@ There are 2 modes of operation:
 
 - `prompt not shared` - each batch has a separate prompt of size `PP` (i.e. `N_KV = B*(PP + TG)`)
 - `prompt is shared` - there is a common prompt of size `PP` used by all batches (i.e. `N_KV = PP + B*TG`)
 
 ```bash
-./batched-bench MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] [MMQ] <PP> <TG> <PL>
+./batched-bench -m model.gguf -c 2048 -b 2048 -ub 512 -npp 128,256,512 -ntg 128,256 -npl 1,2,4,8,16,32 [-pps]
 
 # LLaMA 7B, F16, N_KV_MAX = 16384 (8GB), prompt not shared
-./batched-bench ./models/llama-7b/ggml-model-f16.gguf 16384 2048 512 0 99
+./batched-bench -m ./models/llama-7b/ggml-model-f16.gguf -c 16384 -b 2048 -ub 512 -ngl 99
 
 # LLaMA 7B, Q8_0, N_KV_MAX = 16384 (8GB), prompt is shared
-./batched-bench ./models/llama-7b/ggml-model-q8_0.gguf 16384 2048 512 1 99
+./batched-bench -m ./models/llama-7b/ggml-model-q8_0.gguf -c 16384 -b 2048 -ub 512 -ngl 99 -pps
 
 # custom set of batches
-./batched-bench ./models/llama-7b/ggml-model-q8_0.gguf 2048 512 512 0 999 0 128,256,512 128,256 1,2,4,8,16,32
+./batched-bench -m ./models/llama-7b/ggml-model-q8_0.gguf -c 2048 -b 512 -ub 512 -ngl 999 -npp 128,256,512 -ntg 128,256 -npl 1,2,4,8,16,32
 ```
 
 ## Sample results
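To make the README's two modes concrete: with `-npp 128 -ntg 128 -npl 4`, per-sequence prompts give `N_KV = 4*(128 + 128) = 1024`, while a shared prompt (`-pps`) gives `N_KV = 128 + 4*128 = 640`. The context size passed via `-c` must cover the largest `N_KV` the chosen lists can produce.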

diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp

@@ -28,67 +28,27 @@ static std::vector<int> parse_list(char * p) {
     return ret;
 }
 
+static void print_usage(int argc, char ** argv, const gpt_params & params) {
+    gpt_params_print_usage(argc, argv, params);
+
+    LOG_TEE("\nexample usage:\n");
+    LOG_TEE("\n    %s -m model.gguf -c 2048 -b 2048 -ub 512 -npp 128,256,512 -ntg 128,256 -npl 1,2,4,8,16,32 [-pps]\n", argv[0]);
+    LOG_TEE("\n");
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
 
-    if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [FATTN] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
-        printf("  <PP>, <TG> and PL are comma-separated lists of numbers without spaces\n\n");
-        printf("  example: %s ggml-model-f16.gguf 2048 2048 512 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]);
+    if (!gpt_params_parse(argc, argv, params)) {
+        print_usage(argc, argv, params);
         return 1;
     }
 
-    int n_kv_max     = 2048;
-    int n_batch      = 2048;
-    int n_ubatch     = 512;
-    bool flash_attn  = false;
-    int is_pp_shared = 0;
-    int n_gpu_layers = 0;
+    int is_pp_shared = params.is_pp_shared;
 
-    std::vector<int> n_pp = { 128, 256, 512, 1024, 2048, 3584, 7680, };
-    std::vector<int> n_tg = { 128, 256, };
-    std::vector<int> n_pl = { 1, 2, 4, 8, 16, 32, };
-    //std::vector<int> n_pl = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 32, };
-
-    if (argc >= 2) {
-        params.model = argv[1];
-    }
-
-    if (argc >= 3) {
-        n_kv_max = std::atoi(argv[2]);
-    }
-
-    if (argc >= 4) {
-        n_batch = std::atoi(argv[3]);
-    }
-
-    if (argc >= 5) {
-        n_ubatch = std::atoi(argv[4]);
-    }
-
-    if (argc >= 6) {
-        flash_attn = std::atoi(argv[5]);
-    }
-
-    if (argc >= 7) {
-        is_pp_shared = std::atoi(argv[6]);
-    }
-
-    if (argc >= 8) {
-        n_gpu_layers = std::atoi(argv[7]);
-    }
-
-    if (argc >= 9) {
-        n_pp = parse_list(argv[8]);
-    }
-
-    if (argc >= 10) {
-        n_tg = parse_list(argv[9]);
-    }
-
-    if (argc >= 11) {
-        n_pl = parse_list(argv[10]);
-    }
+    std::vector<int> n_pp = params.n_pp;
+    std::vector<int> n_tg = params.n_tg;
+    std::vector<int> n_pl = params.n_pl;
 
     // init LLM
@@ -97,12 +57,7 @@ int main(int argc, char ** argv) {
     // initialize the model
 
-    llama_model_params model_params = llama_model_default_params();
-
-    const std::vector<float> t_split(llama_max_devices(), 0.0f);
-
-    model_params.n_gpu_layers = n_gpu_layers;
-    model_params.tensor_split = t_split.data();
+    llama_model_params model_params = llama_model_params_from_gpt_params(params);
 
     llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
@@ -111,16 +66,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    llama_context_params ctx_params = llama_context_default_params();
-
-    ctx_params.seed       = 1234;
-    ctx_params.n_ctx      = n_kv_max;
-    ctx_params.n_batch    = n_batch;
-    ctx_params.n_ubatch   = n_ubatch;
-    ctx_params.flash_attn = flash_attn;
-
-    ctx_params.n_threads       = params.n_threads;
-    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+    llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
 
     // ensure enough sequences are available
     ctx_params.n_seq_max = *std::max_element(n_pl.begin(), n_pl.end());
@@ -132,6 +78,8 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    const int32_t n_kv_max = llama_n_ctx(ctx);
+
     llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
 
     // decode in batches of ctx_params.n_batch tokens
@@ -175,7 +123,7 @@ int main(int argc, char ** argv) {
     }
 
     LOG_TEE("\n");
-    LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, flash_attn, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+    LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
     LOG_TEE("\n");
 
     LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");

diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp

@@ -41,20 +41,6 @@ static std::string join(const std::vector<T> & values, const std::string & delim) {
     return str.str();
 }
 
-template<class T>
-static std::vector<T> split(const std::string & str, char delim) {
-    std::vector<T> values;
-    std::istringstream str_stream(str);
-    std::string token;
-    while (std::getline(str_stream, token, delim)) {
-        T value;
-        std::istringstream token_stream(token);
-        token_stream >> value;
-        values.push_back(value);
-    }
-    return values;
-}
-
 template<typename T, typename F>
 static std::vector<std::string> transform_to_str(const std::vector<T> & values, F f) {
     std::vector<std::string> str_values;
@@ -300,28 +286,28 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<std::string>(argv[i], split_delim);
+            auto p = string_split<std::string>(argv[i], split_delim);
             params.model.insert(params.model.end(), p.begin(), p.end());
         } else if (arg == "-p" || arg == "--n-prompt") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<int>(argv[i], split_delim);
+            auto p = string_split<int>(argv[i], split_delim);
             params.n_prompt.insert(params.n_prompt.end(), p.begin(), p.end());
         } else if (arg == "-n" || arg == "--n-gen") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<int>(argv[i], split_delim);
+            auto p = string_split<int>(argv[i], split_delim);
             params.n_gen.insert(params.n_gen.end(), p.begin(), p.end());
         } else if (arg == "-pg") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<std::string>(argv[i], ',');
+            auto p = string_split<std::string>(argv[i], ',');
             if (p.size() != 2) {
                 invalid_param = true;
                 break;
@@ -332,21 +318,21 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<int>(argv[i], split_delim);
+            auto p = string_split<int>(argv[i], split_delim);
             params.n_batch.insert(params.n_batch.end(), p.begin(), p.end());
         } else if (arg == "-ub" || arg == "--ubatch-size") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<int>(argv[i], split_delim);
+            auto p = string_split<int>(argv[i], split_delim);
             params.n_ubatch.insert(params.n_ubatch.end(), p.begin(), p.end());
         } else if (arg == "-ctk" || arg == "--cache-type-k") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<std::string>(argv[i], split_delim);
+            auto p = string_split<std::string>(argv[i], split_delim);
             std::vector<ggml_type> types;
             for (const auto & t : p) {
                 ggml_type gt = ggml_type_from_name(t);
@@ -362,7 +348,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<std::string>(argv[i], split_delim);
+            auto p = string_split<std::string>(argv[i], split_delim);
             std::vector<ggml_type> types;
             for (const auto & t : p) {
                 ggml_type gt = ggml_type_from_name(t);
@@ -378,14 +364,14 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<int>(argv[i], split_delim);
+            auto p = string_split<int>(argv[i], split_delim);
             params.n_threads.insert(params.n_threads.end(), p.begin(), p.end());
         } else if (arg == "-ngl" || arg == "--n-gpu-layers") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<int>(argv[i], split_delim);
+            auto p = string_split<int>(argv[i], split_delim);
             params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end());
         } else if (arg == "-rpc" || arg == "--rpc") {
             if (++i >= argc) {
@@ -398,7 +384,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<std::string>(argv[i], split_delim);
+            auto p = string_split<std::string>(argv[i], split_delim);
             std::vector<llama_split_mode> modes;
             for (const auto & m : p) {
                 llama_split_mode mode;
@@ -420,13 +406,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 invalid_param = true;
                 break;
             }
-            params.main_gpu = split<int>(argv[i], split_delim);
+            params.main_gpu = string_split<int>(argv[i], split_delim);
         } else if (arg == "-nkvo" || arg == "--no-kv-offload") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<bool>(argv[i], split_delim);
+            auto p = string_split<bool>(argv[i], split_delim);
             params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end());
         } else if (arg == "--numa") {
             if (++i >= argc) {
@@ -444,28 +430,28 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<bool>(argv[i], split_delim);
+            auto p = string_split<bool>(argv[i], split_delim);
             params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
         } else if (arg == "-mmp" || arg == "--mmap") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<bool>(argv[i], split_delim);
+            auto p = string_split<bool>(argv[i], split_delim);
             params.use_mmap.insert(params.use_mmap.end(), p.begin(), p.end());
         } else if (arg == "-embd" || arg == "--embeddings") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<bool>(argv[i], split_delim);
+            auto p = string_split<bool>(argv[i], split_delim);
             params.embeddings.insert(params.embeddings.end(), p.begin(), p.end());
         } else if (arg == "-ts" || arg == "--tensor-split") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            for (auto ts : split<std::string>(argv[i], split_delim)) {
+            for (auto ts : string_split<std::string>(argv[i], split_delim)) {
                 // split string by ; and /
                 const std::regex regex{R"([;/]+)"};
                 std::sregex_token_iterator it{ts.begin(), ts.end(), regex, -1};