Alternative way to output PPL results

This commit is contained in:
Iwan Kawrakow 2023-08-22 16:12:45 +03:00
parent b791d1f489
commit efdfc41e49
3 changed files with 19 additions and 3 deletions

View file

@ -424,6 +424,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.ppl_stride = std::stoi(argv[i]); params.ppl_stride = std::stoi(argv[i]);
} else if (arg == "--ppl-output-type") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.ppl_output_type = std::stoi(argv[i]);
} else if (arg == "--hellaswag") { } else if (arg == "--hellaswag") {
params.hellaswag = true; params.hellaswag = true;
} else if (arg == "--hellaswag-tasks") { } else if (arg == "--hellaswag-tasks") {

View file

@ -65,7 +65,9 @@ struct gpt_params {
std::string lora_base = ""; // base model path for the lora adapter std::string lora_base = ""; // base model path for the lora adapter
int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
int ppl_output_type = 0; // = 0 -> ppl output is as usual; = 1 -> ppl output is one "num_tokens ppl" pair per line
// (which is more convenient to use for plotting)
//
bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score

View file

@ -125,7 +125,11 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
++count; ++count;
} }
// perplexity is e^(average negative log-likelihood) // perplexity is e^(average negative log-likelihood)
printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); if (params.ppl_output_type == 0) {
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
} else {
printf("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
}
fflush(stdout); fflush(stdout);
} }
printf("\n"); printf("\n");
@ -226,7 +230,11 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
++count; ++count;
} }
// perplexity is e^(average negative log-likelihood) // perplexity is e^(average negative log-likelihood)
printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); if (params.ppl_output_type == 0) {
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
} else {
printf("%8d %.4lf\n", i*params.n_ctx, std::exp(nll / count));
}
fflush(stdout); fflush(stdout);
} }
printf("\n"); printf("\n");