TruthfulQA: works but the result is bad

I know it works because if I convert the HellaSwag validation
data to the binary format used in the truthful_qa_score() function
I get the exact same result as from the hellaswag_score() function.
But I guess the questions are tricky, and the way I have combined
question + answer is very likely not the best.
The TruthfulQA validation dataset contains 817 questions, with a
random-chance result of around 19%. With this version I get
29.1% for Mistral-7B and 55.2% for Mistral-7B-Instruct-v0.2.
The HF leaderboard results for these two models are
42.2% and 68.3%, respectively.
This commit is contained in:
Iwan Kawrakow 2024-01-20 10:26:15 +02:00
parent 6ce06623fd
commit b0a4873697

View file

@ -540,14 +540,14 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
// This is needed as usual for LLaMA models
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
// The tasks should be randomized so the score stabilizes quickly.
bool randomize_tasks = true;
// Number of tasks to use when computing the score
if (params.hellaswag_tasks < hs_task_count) {
hs_task_count = params.hellaswag_tasks;
}
// The tasks should be randomized so the score stabilizes quickly.
bool randomize_tasks = true;
// The random seed should not impact the final result if the computation is done over enough tasks, so kept hardcoded for now
std::mt19937 rng(1);
@ -1079,6 +1079,8 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
// Calculates TruthfulQA score (multiple choice with single correct answer) from prompt
//
// Data extracted from https://huggingface.co/datasets/truthful_qa
// The validation dataset in the binary format that is being used can be found at
// https://huggingface.co/datasets/ikawrakow/validation-datasets-for-llama.cpp
//
std::istringstream strstream(params.prompt);
@ -1207,6 +1209,7 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
int n_done = 0;
int n_correct = 0;
int n_tot_answers = 0;
for (size_t i0 = 0; i0 < tasks.size(); i0++) {
int n_cur = 0;
@ -1246,6 +1249,8 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
}
}
s0 += num_answers;
cur_task.i_batch = i_batch;
i_batch += cur_task.required_tokens;
@ -1289,6 +1294,13 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
// compute the logprobs for each ending of the decoded tasks
for (size_t i = i0; i < i1; ++i) {
auto & cur_task = tasks[i];
//printf("==== Evaluating <%s> with correct answer ", cur_task.question.c_str());
//for (int j = 0; j < int(cur_task.mc1.labels.size()); ++j) {
// if (cur_task.mc1.labels[j] == 1) {
// printf("%d", j+1);
// }
//}
//printf("\n common_prefix: %zu\n", cur_task.common_prefix);
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(cur_task.i_batch + cur_task.common_prefix - 1), n_vocab*sizeof(float));
@ -1298,11 +1310,29 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
size_t count = 1;
float log_prob = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
//printf(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob);
//for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
// printf(" %zu %g\n", ir, eval_results[ir]);
// ++count;
// log_prob += eval_results[ir++];
//}
//size_t count = 0;
//float log_prob = 0;
//printf(" <%s>\n", cur_task.mc1.answers[s].c_str());
//float log_prob = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
//printf(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob);
for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
//printf(" %zu %g\n", ir, eval_results[ir]);
++count;
log_prob += eval_results[ir++];
}
//if (!count) {
// ++count;
// log_prob += std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
//}
cur_task.log_probs[s] = log_prob / count;
//printf(" Final: %g\n", log_prob / count);
//printf(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob/count);
}
// Find the ending with maximum logprob
@ -1315,13 +1345,14 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
}
}
n_tot_answers += cur_task.log_probs.size();
if (cur_task.mc1.labels[logprob_max_idx] == 1) {
++n_correct;
}
++n_done;
// Print the accumulated accuracy mean x 100
printf("%zu\t%.8lf\n", i + 1, 100.*n_correct/n_done);
printf("%d\t%.8lf\n", n_done, 100.*n_correct/n_done);
fflush(stdout);
}
@ -1330,6 +1361,15 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
llama_batch_free(batch);
if (n_done < 100) return;
float p = 1.f*n_correct/n_done;
float sigma = sqrt(p*(1-p)/(n_done-1));
printf("\n Final result: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
p = 1.f*n_done/n_tot_answers;
sigma = sqrt(p*(1-p)/(n_done-1));
printf("Random chance: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
printf("\n");
}