TruthfulQA: works but the result is bad
I know it works because if I convert the HellaSwag validation data to the binary format used by the truthful_qa_score() function, I get exactly the same result as from the hellaswag_score() function. But the questions are tricky, and the way I combine question + answer is very likely not the best. The TruthfulQA validation dataset contains 817 questions, and the random-chance result is around 19%. With this version I get 29.1% for Mistral-7B and 55.2% for Mistral-7B-Instruct-v0.2. The HF leaderboard results for these two models are 42.2% and 68.3%, respectively.
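(For context on the 19% figure: the patch below estimates random chance as n_done/n_tot_answers, i.e. the number of questions divided by the total number of MC1 answer options across all questions. A random-chance score of roughly 0.19 therefore implies about 1/0.19 ≈ 5.3 answer options per question on average, or on the order of 4300 scored answers over the 817 questions.)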
parent 6ce06623fd
commit b0a4873697
1 changed file with 44 additions and 4 deletions
@@ -540,14 +540,14 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     // This is needed as usual for LLaMA models
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
 
-    // The tasks should be randomized so the score stabilizes quickly.
-    bool randomize_tasks = true;
-
     // Number of tasks to use when computing the score
     if (params.hellaswag_tasks < hs_task_count) {
         hs_task_count = params.hellaswag_tasks;
     }
 
+    // The tasks should be randomized so the score stabilizes quickly.
+    bool randomize_tasks = true;
+
     // The random seed should not impact the final result if the computation is done over enough tasks, so kept hardcoded for now
     std::mt19937 rng(1);
 
@@ -1079,6 +1079,8 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
     // Calculates TruthFulQA score (multiple choice with single correct answer) from prompt
     //
     // Data extracted from https://huggingface.co/datasets/truthful_qa
+    // The validation dataset in the binary format that is being used can be found at
+    // https://huggingface.co/datasets/ikawrakow/validation-datasets-for-llama.cpp
     //
 
     std::istringstream strstream(params.prompt);
@@ -1207,6 +1209,7 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
 
     int n_done = 0;
     int n_correct = 0;
+    int n_tot_answers = 0;
 
     for (size_t i0 = 0; i0 < tasks.size(); i0++) {
         int n_cur = 0;
@@ -1246,6 +1249,8 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
             }
         }
 
+        s0 += num_answers;
+
        cur_task.i_batch = i_batch;
        i_batch += cur_task.required_tokens;
 
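The i_batch bookkeeping above packs the tasks back-to-back into a single decoded batch, so each task can later locate its logits at an offset derived from cur_task.i_batch (see the memcpy below). A toy sketch of just that offset arithmetic (the Task struct is illustrative and mirrors only the two fields used here):

    #include <cstddef>
    #include <vector>

    struct Task {
        std::size_t required_tokens; // tokens this task contributes to the batch
        std::size_t i_batch;         // first position of this task within the batch
    };

    // Assign each task its starting offset; the next task begins right after.
    static void assign_offsets(std::vector<Task> & tasks) {
        std::size_t i_batch = 0;
        for (auto & t : tasks) {
            t.i_batch = i_batch;
            i_batch  += t.required_tokens;
        }
    }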
@@ -1289,6 +1294,13 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
         // compute the logprobs for each ending of the decoded tasks
         for (size_t i = i0; i < i1; ++i) {
             auto & cur_task = tasks[i];
+            //printf("==== Evaluating <%s> with correct answer ", cur_task.question.c_str());
+            //for (int j = 0; j < int(cur_task.mc1.labels.size()); ++j) {
+            //    if (cur_task.mc1.labels[j] == 1) {
+            //        printf("%d", j+1);
+            //    }
+            //}
+            //printf("\n common_prefix: %zu\n", cur_task.common_prefix);
 
             std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(cur_task.i_batch + cur_task.common_prefix - 1), n_vocab*sizeof(float));
 
@@ -1298,11 +1310,29 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
             for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
+                size_t count = 1;
+                float  log_prob = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
+                //printf(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob);
+                //for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
+                //    printf(" %zu %g\n", ir, eval_results[ir]);
+                //    ++count;
+                //    log_prob += eval_results[ir++];
+                //}
+                //size_t count = 0;
+                //float log_prob = 0;
+                //printf(" <%s>\n", cur_task.mc1.answers[s].c_str());
+                //float log_prob = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
+                //printf(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob);
                 for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
                     //printf(" %zu %g\n", ir, eval_results[ir]);
                     ++count;
                     log_prob += eval_results[ir++];
                 }
+                //if (!count) {
+                //    ++count;
+                //    log_prob += std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
+                //}
                 cur_task.log_probs[s] = log_prob / count;
+                //printf(" Final: %g\n", log_prob / count);
                 //printf(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob/count);
             }
 
             // Find the ending with maximum logprob
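For readers skimming the diff: the new scoring rule boils down to giving each candidate answer the mean per-token log-probability of its tokens after the common question prefix (seeded with the first answer token, hence count = 1), then picking the answer with the highest mean. A minimal self-contained sketch of that rule (pick_answer and its input layout are illustrative, not part of the patch):

    #include <cstddef>
    #include <vector>

    // logprobs[s] holds the log-probabilities of answer s's tokens after the
    // common question prefix, mirroring log_prob/count in the patch above.
    static std::size_t pick_answer(const std::vector<std::vector<float>> & logprobs) {
        std::size_t best = 0;
        float best_score = -1e30f;
        for (std::size_t s = 0; s < logprobs.size(); ++s) {
            float sum = 0.f;
            for (float lp : logprobs[s]) sum += lp;
            const float mean = logprobs[s].empty() ? -1e30f : sum / float(logprobs[s].size());
            if (mean > best_score) { best_score = mean; best = s; }
        }
        return best;
    }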
@@ -1315,13 +1345,14 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
                 }
             }
 
+            n_tot_answers += cur_task.log_probs.size();
             if (cur_task.mc1.labels[logprob_max_idx] == 1) {
                 ++n_correct;
             }
             ++n_done;
 
             // Print the accumulated accuracy mean x 100
-            printf("%zu\t%.8lf\n", i + 1, 100.*n_correct/n_done);
+            printf("%d\t%.8lf\n", n_done, 100.*n_correct/n_done);
             fflush(stdout);
         }
 
@@ -1330,6 +1361,15 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
 
     llama_batch_free(batch);
 
+    if (n_done < 100) return;
+
+    float p = 1.f*n_correct/n_done;
+    float sigma = sqrt(p*(1-p)/(n_done-1));
+    printf("\n Final result: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
+    p = 1.f*n_done/n_tot_answers;
+    sigma = sqrt(p*(1-p)/(n_done-1));
+    printf("Random chance: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
+
     printf("\n");
 }
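As a quick sanity check of the error estimate (illustrative arithmetic using the 29.1% Mistral-7B score from the commit message, not the output of an actual run): with p = 0.291 and n_done = 817, sigma = sqrt(0.291*0.709/816) ≈ 0.0159, so the final line would read roughly 29.10 +/- 1.59.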