From 3fa6726351fc267ebe7e0524220d064b3d68f1d9 Mon Sep 17 00:00:00 2001 From: MaggotHATE Date: Sun, 3 Dec 2023 12:46:09 +0500 Subject: [PATCH] More readable samplers input string, fixed help --- common/common.cpp | 50 +++++++++++++++++++++++++++++++++++++++++++++++ common/common.h | 6 ++++++ 2 files changed, 56 insertions(+) diff --git a/common/common.cpp b/common/common.cpp index b148e3a78..75588e156 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -280,6 +280,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.yarn_beta_slow = std::stof(argv[i]); } else if (arg == "--memory-f32") { params.memory_f16 = false; + } else if (arg == "--samplers") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.samplers_sequence = parse_samplers_input(argv[i]); } else if (arg == "--sampling-seq") { if (++i >= argc) { invalid_param = true; @@ -767,6 +773,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); + printf(" --samplers samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n"); + printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str()); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p); @@ -892,6 +900,48 @@ std::string gpt_random_prompt(std::mt19937 & rng) { GGML_UNREACHABLE(); } +// +// String parsing +// + +std::string parse_samplers_input(std::string input){ + std::string output = ""; + // since samplers names are written multiple ways + // make it ready for both system names and input names + std::unordered_map samplers_symbols{ + {"top_k", 'k'}, + {"top-k", 'k'}, + {"top_p", 'p'}, + {"top-p", 'p'}, + {"nucleus", 'p'}, + {"typical_p", 'y'}, + {"typical-p", 'y'}, + {"typical", 'y'}, + {"min_p", 'm'}, + {"min-p", 'm'}, + {"tfs_z", 'f'}, + {"tfs-z", 'f'}, + {"tfs", 'f'}, + {"temp", 't'}, + {"temperature",'t'} + }; + // expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p" + size_t separator = input.find(';'); + while (separator != input.npos){ + std::string name = input.substr(0,separator); + input = input.substr(separator+1); + separator = input.find(';'); + + if (samplers_symbols.find(name) != samplers_symbols.end()){ + output += samplers_symbols[name]; + } + } + if (samplers_symbols.find(input) != samplers_symbols.end()){ + output += samplers_symbols[input]; + } + return output; +} + // // Model utils // diff --git a/common/common.h b/common/common.h index 2f6fe48ab..534f7b132 100644 --- a/common/common.h +++ b/common/common.h @@ -141,6 +141,12 @@ std::string gpt_random_prompt(std::mt19937 & rng); void process_escapes(std::string& input); +// +// String parsing +// + +std::string parse_samplers_input(std::string input); + // // Model utils //