KL-divergence (#5076)
* kl-divergence: add the ability to save all logits to a file
* Add the ability to compute KL-divergence
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent
780e24a22e
commit
6f9939d119
3 changed files with 329 additions and 2 deletions
|
@ -672,6 +672,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
|||
if (params.logdir.back() != DIRECTORY_SEPARATOR) {
|
||||
params.logdir += DIRECTORY_SEPARATOR;
|
||||
}
|
||||
} else if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.logits_file = argv[i];
|
||||
} else if (arg == "--perplexity" || arg == "--all-logits") {
|
||||
params.logits_all = true;
|
||||
} else if (arg == "--ppl-stride") {
|
||||
|
@ -716,6 +722,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
|||
break;
|
||||
}
|
||||
params.multiple_choice_tasks = std::stoi(argv[i]);
|
||||
} else if (arg == "--kl-divergence") {
|
||||
params.kl_divergence = true;
|
||||
} else if (arg == "--ignore-eos") {
|
||||
params.ignore_eos = true;
|
||||
} else if (arg == "--no-penalize-nl") {
|
||||
|
@ -967,6 +975,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
printf(" --winogrande-tasks N number of tasks to use when computing the Winogrande score (default: %zu)\n", params.winogrande_tasks);
|
||||
printf(" --multiple-choice compute multiple choice score over random tasks from datafile supplied with -f\n");
|
||||
printf(" --multiple-choice-tasks N number of tasks to use when computing the multiple choice score (default: %zu)\n", params.winogrande_tasks);
|
||||
printf(" --kl-divergence computes KL-divergence to logits provided via --kl-divergence-base");
|
||||
printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
|
||||
printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
|
||||
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue