Rename truthful_qa to multiple_choice

This commit is contained in:
Iwan Kawrakow 2024-01-21 11:53:38 +02:00
parent d86f80f416
commit 92540e44c2
3 changed files with 42 additions and 30 deletions

View file

@@ -708,14 +708,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.winogrande_tasks = std::stoi(argv[i]); params.winogrande_tasks = std::stoi(argv[i]);
} else if (arg == "--truthful-qa") { } else if (arg == "--multiple-choice") {
params.truthful_qa = true; params.multiple_choice = true;
} else if (arg == "--truthful-qa-tasks") { } else if (arg == "--multiple-choice-tasks") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.thruthful_qa_tasks = std::stoi(argv[i]); params.multiple_choice_tasks = std::stoi(argv[i]);
} else if (arg == "--ignore-eos") { } else if (arg == "--ignore-eos") {
params.ignore_eos = true; params.ignore_eos = true;
} else if (arg == "--no-penalize-nl") { } else if (arg == "--no-penalize-nl") {
@@ -915,6 +915,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --in-suffix STRING string to suffix after user inputs with (default: empty)\n"); printf(" --in-suffix STRING string to suffix after user inputs with (default: empty)\n");
printf(" -f FNAME, --file FNAME\n"); printf(" -f FNAME, --file FNAME\n");
printf(" prompt file to start generation.\n"); printf(" prompt file to start generation.\n");
printf(" -bf FNAME, --binary-file FNAME\n");
printf(" binary file containing multiple choice tasks.\n");
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
@@ -963,8 +965,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks); printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
printf(" --winogrande compute Winogrande score over random tasks from datafile supplied with -f\n"); printf(" --winogrande compute Winogrande score over random tasks from datafile supplied with -f\n");
printf(" --winogrande-tasks N number of tasks to use when computing the Winogrande score (default: %zu)\n", params.winogrande_tasks); printf(" --winogrande-tasks N number of tasks to use when computing the Winogrande score (default: %zu)\n", params.winogrande_tasks);
printf(" --truthful-qa compute TruthFullQA multiple choice score over random tasks from datafile supplied with -f\n"); printf(" --multiple-choice compute multiple choice score over random tasks from datafile supplied with -f\n");
printf(" --truthful-qa-tasks N number of tasks to use when computing the TruthFullQA score (default: %zu)\n", params.winogrande_tasks); printf(" --multiple-choice-tasks N number of tasks to use when computing the multiple choice score (default: %zu)\n", params.winogrande_tasks);
printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft); printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);

View file

@@ -108,8 +108,8 @@ struct gpt_params {
bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt
size_t winogrande_tasks= 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed size_t winogrande_tasks= 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed
bool truthful_qa = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt
size_t thruthful_qa_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed
bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
bool random_prompt = false; // do not randomize prompt if none provided bool random_prompt = false; // do not randomize prompt if none provided

View file

@@ -1040,7 +1040,7 @@ static bool deserialize_string(std::istream& in, std::string& str) {
return false; return false;
} }
struct truthful_qa_answer { struct multiple_choice_answers {
std::vector<std::string> answers; std::vector<std::string> answers;
std::vector<int> labels; std::vector<int> labels;
bool deserialize(std::istream& in) { bool deserialize(std::istream& in) {
@@ -1057,10 +1057,10 @@ struct truthful_qa_answer {
} }
}; };
struct truthful_qa_task { struct multiple_choice_task {
std::string question; std::string question; // the question (or context that needs to be continued)
truthful_qa_answer mc1; multiple_choice_answers mc1; // possible answers (continuations) with a single correct answer
truthful_qa_answer mc2; multiple_choice_answers mc2; // possible answers (continuations) with multiple correct answers - not handled yet
bool deserialize(std::istream& in) { bool deserialize(std::istream& in) {
if (!deserialize_string(in, question)) return false; if (!deserialize_string(in, question)) return false;
return mc1.deserialize(in) && mc2.deserialize(in); return mc1.deserialize(in) && mc2.deserialize(in);
@@ -1074,7 +1074,7 @@ struct truthful_qa_task {
std::vector<float> log_probs; std::vector<float> log_probs;
}; };
static bool truthful_qa_prepare_one_task(llama_context * ctx, bool add_bos, truthful_qa_task& task, bool log_error) { static bool multiple_choice_prepare_one_task(llama_context * ctx, bool add_bos, multiple_choice_task& task, bool log_error) {
if (task.question.empty() || task.mc1.answers.empty()) { if (task.question.empty() || task.mc1.answers.empty()) {
if (log_error) { if (log_error) {
printf("%s: found bad task with empty question and/or answers\n", __func__); printf("%s: found bad task with empty question and/or answers\n", __func__);
@@ -1117,13 +1117,23 @@ static bool truthful_qa_prepare_one_task(llama_context * ctx, bool add_bos, trut
return true; return true;
} }
static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
// Calculates TruthFulQA score (multiple choice with single correct answer) from prompt
// //
// Data extracted from https://huggingface.co/datasets/truthful_qa // Calculates score for multiple choice tasks with single correct answer from prompt.
// The validation dataset in the binary format that is being used can be found at // Commonly used LLM evaluation metrics of this type are
// * ARC
// * HellaSwag
// * MMLU
// * TruthfulQA
//
// Validation datasets for these 4 tests can be found at
// https://huggingface.co/datasets/ikawrakow/validation-datasets-for-llama.cpp // https://huggingface.co/datasets/ikawrakow/validation-datasets-for-llama.cpp
// The data for these datasets was extracted from
// git@hf.co:datasets/allenai/ai2_arc
// https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
// git@hf.co:datasets/Stevross/mmlu
// https://huggingface.co/datasets/truthful_qa
// //
static void multiple_choice_score(llama_context * ctx, const gpt_params & params) {
std::istringstream strstream(params.prompt); std::istringstream strstream(params.prompt);
uint32_t n_task; uint32_t n_task;
@@ -1140,8 +1150,8 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
return; return;
} }
std::vector<truthful_qa_task> tasks; std::vector<multiple_choice_task> tasks;
if (params.thruthful_qa_tasks == 0 || params.thruthful_qa_tasks >= (size_t)n_task) { if (params.multiple_choice_tasks == 0 || params.multiple_choice_tasks >= (size_t)n_task) {
// Use all tasks // Use all tasks
tasks.resize(n_task); tasks.resize(n_task);
printf("%s: reading tasks", __func__); printf("%s: reading tasks", __func__);
@@ -1158,12 +1168,12 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
printf("done\n"); printf("done\n");
} }
else { else {
printf("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.thruthful_qa_tasks, n_task); printf("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.multiple_choice_tasks, n_task);
std::mt19937 rng(1); std::mt19937 rng(1);
std::vector<int> aux(n_task); std::vector<int> aux(n_task);
for (uint32_t i = 0; i < n_task; ++i) aux[i] = i; for (uint32_t i = 0; i < n_task; ++i) aux[i] = i;
float scale = 1.f/(1.f + (float)std::mt19937::max()); float scale = 1.f/(1.f + (float)std::mt19937::max());
tasks.resize(params.thruthful_qa_tasks); tasks.resize(params.multiple_choice_tasks);
for (auto& task : tasks) { for (auto& task : tasks) {
int j = (int)(scale * rng() * aux.size()); int j = (int)(scale * rng() * aux.size());
int idx = aux[j]; int idx = aux[j];
@@ -1175,7 +1185,7 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
return; return;
} }
} }
n_task = params.thruthful_qa_tasks; n_task = params.multiple_choice_tasks;
} }
// This is needed as usual for LLaMA models // This is needed as usual for LLaMA models
@@ -1200,7 +1210,7 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
} }
int last = std::min(first + k_chunk, num_tasks); int last = std::min(first + k_chunk, num_tasks);
for (int i = first; i < last; ++i) { for (int i = first; i < last; ++i) {
if (!truthful_qa_prepare_one_task(ctx, add_bos, tasks[i], false)) ++n_bad_local; if (!multiple_choice_prepare_one_task(ctx, add_bos, tasks[i], false)) ++n_bad_local;
} }
} }
}; };
@@ -1222,7 +1232,7 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
int i_task = 0; int i_task = 0;
for (auto& task : tasks) { for (auto& task : tasks) {
++i_task; ++i_task;
if (!truthful_qa_prepare_one_task(ctx, add_bos, task, true)) { if (!multiple_choice_prepare_one_task(ctx, add_bos, task, true)) {
return; return;
} }
if (i_task%n_dot == 0) { if (i_task%n_dot == 0) {
@@ -1465,8 +1475,8 @@ int main(int argc, char ** argv) {
hellaswag_score(ctx, params); hellaswag_score(ctx, params);
} else if (params.winogrande) { } else if (params.winogrande) {
winogrande_score(ctx, params); winogrande_score(ctx, params);
} else if (params.truthful_qa) { } else if (params.multiple_choice) {
truthful_qa_score(ctx, params); multiple_choice_score(ctx, params);
} else { } else {
results = perplexity(ctx, params); results = perplexity(ctx, params);
} }