More readable samplers input string, fixed help
This commit is contained in:
parent
e6dc166566
commit
3fa6726351
2 changed files with 56 additions and 0 deletions
|
@ -280,6 +280,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||||
params.yarn_beta_slow = std::stof(argv[i]);
|
params.yarn_beta_slow = std::stof(argv[i]);
|
||||||
} else if (arg == "--memory-f32") {
|
} else if (arg == "--memory-f32") {
|
||||||
params.memory_f16 = false;
|
params.memory_f16 = false;
|
||||||
|
} else if (arg == "--samplers") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
sparams.samplers_sequence = parse_samplers_input(argv[i]);
|
||||||
} else if (arg == "--sampling-seq") {
|
} else if (arg == "--sampling-seq") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
|
@ -767,6 +773,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||||
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
|
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
|
||||||
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
|
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
|
||||||
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||||
|
printf(" --samplers samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n");
|
||||||
|
printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str());
|
||||||
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
|
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
|
||||||
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
|
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
|
||||||
printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
|
printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
|
||||||
|
@ -892,6 +900,48 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
|
||||||
GGML_UNREACHABLE();
|
GGML_UNREACHABLE();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// String parsing
|
||||||
|
//
|
||||||
|
|
||||||
|
// Converts a human-readable, ';'-separated list of sampler names into the
// compact sequence string that holds one character per sampler.
//
// Expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p"
//
// input:   sampler names separated by ';'. Whitespace around a name is
//          ignored, so "top_k; top_p" is accepted as well.
// returns: one symbol per recognized name, in the order the names appear.
//          Names that are empty or not in the alias table are skipped, so
//          the result may be empty.
std::string parse_samplers_input(std::string input){
    // Sampler names are written in several ways (underscore, dash, short
    // alias), so every accepted spelling maps to the same symbol.
    // Built once and kept const: lookups must never insert into the table.
    static const std::unordered_map<std::string, char> samplers_symbols{
        {"top_k",       'k'},
        {"top-k",       'k'},
        {"top_p",       'p'},
        {"top-p",       'p'},
        {"nucleus",     'p'},
        {"typical_p",   'y'},
        {"typical-p",   'y'},
        {"typical",     'y'},
        {"min_p",       'm'},
        {"min-p",       'm'},
        {"tfs_z",       'f'},
        {"tfs-z",       'f'},
        {"tfs",         'f'},
        {"temp",        't'},
        {"temperature", 't'}
    };
    static const char * const whitespace = " \t\r\n";

    std::string output;

    // Walk the input with an index instead of repeatedly re-allocating the
    // remainder of the string; the final token (after the last ';', or the
    // whole input when there is no ';') is handled by the same loop body.
    size_t start = 0;
    while (start <= input.size()) {
        size_t separator = input.find(';', start);
        if (separator == std::string::npos) {
            separator = input.size();
        }

        // Trim the token [start, separator) before looking it up.
        const size_t first = input.find_first_not_of(whitespace, start);
        if (first != std::string::npos && first < separator) {
            const size_t last = input.find_last_not_of(whitespace, separator - 1);
            const auto   it   = samplers_symbols.find(input.substr(first, last - first + 1));
            if (it != samplers_symbols.end()) {
                output += it->second;
            }
        }

        start = separator + 1;
    }

    return output;
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Model utils
|
// Model utils
|
||||||
//
|
//
|
||||||
|
|
|
@ -141,6 +141,12 @@ std::string gpt_random_prompt(std::mt19937 & rng);
|
||||||
|
|
||||||
void process_escapes(std::string& input);
|
void process_escapes(std::string& input);
|
||||||
|
|
||||||
|
//
|
||||||
|
// String parsing
|
||||||
|
//
|
||||||
|
|
||||||
|
// Converts a ';'-separated list of sampler names
// (e.g. "top_k;tfs;typical;top_p;min_p;temp") into a string holding one
// character per recognized sampler, in input order; unrecognized names are skipped.
std::string parse_samplers_input(std::string input);
|
||||||
|
|
||||||
//
|
//
|
||||||
// Model utils
|
// Model utils
|
||||||
//
|
//
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue