sampling : rename penalty params + reduce size of "prev" vector
This commit is contained in:
parent
84ed48b473
commit
b526561583
7 changed files with 86 additions and 79 deletions
|
@ -241,25 +241,26 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
sparams.repeat_last_n = std::stoi(argv[i]);
|
||||
sparams.penalty_last_n = std::stoi(argv[i]);
|
||||
sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
|
||||
} else if (arg == "--repeat-penalty") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
sparams.repeat_penalty = std::stof(argv[i]);
|
||||
sparams.penalty_repeat = std::stof(argv[i]);
|
||||
} else if (arg == "--frequency-penalty") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
sparams.frequency_penalty = std::stof(argv[i]);
|
||||
sparams.penalty_freq = std::stof(argv[i]);
|
||||
} else if (arg == "--presence-penalty") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
sparams.presence_penalty = std::stof(argv[i]);
|
||||
sparams.penalty_present = std::stof(argv[i]);
|
||||
} else if (arg == "--mirostat") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
|
@ -678,10 +679,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
|
||||
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
|
||||
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
|
||||
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n);
|
||||
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty);
|
||||
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty);
|
||||
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty);
|
||||
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
|
||||
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
|
||||
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
|
||||
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
|
||||
printf(" --mirostat N use Mirostat sampling.\n");
|
||||
printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
|
||||
printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
|
||||
|
@ -1178,7 +1179,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
|
|||
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
|
||||
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
|
||||
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
|
||||
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty);
|
||||
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
|
||||
dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str());
|
||||
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
|
||||
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
|
||||
|
@ -1238,14 +1239,14 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
|
|||
fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false");
|
||||
fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
|
||||
fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
|
||||
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty);
|
||||
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
|
||||
dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str());
|
||||
fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
|
||||
fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
|
||||
fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
|
||||
dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens);
|
||||
fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false");
|
||||
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty);
|
||||
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat);
|
||||
|
||||
fprintf(stream, "reverse_prompt:\n");
|
||||
for (std::string ap : params.antiprompt) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue