TruthfulQA: prepare tasks in parallel for large test datasets

This commit is contained in:
Iwan Kawrakow 2024-01-20 11:44:55 +02:00
parent 21d0ce5e05
commit d86f80f416

View file

@ -1074,6 +1074,48 @@ struct truthful_qa_task {
std::vector<float> log_probs; std::vector<float> log_probs;
}; };
static bool truthful_qa_prepare_one_task(llama_context * ctx, bool add_bos, truthful_qa_task& task, bool log_error) {
if (task.question.empty() || task.mc1.answers.empty()) {
if (log_error) {
printf("%s: found bad task with empty question and/or answers\n", __func__);
}
return false;
}
task.seq_tokens.reserve(task.mc1.answers.size());
for (auto& answer : task.mc1.answers) {
if (answer.empty()) {
if (log_error) {
printf("%s: found empty answer\n", __func__);
}
return false;
}
task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, add_bos));
}
auto min_len = task.seq_tokens.front().size();
for (auto& seq : task.seq_tokens) {
min_len = std::min(min_len, seq.size());
}
task.common_prefix = 0;
for (size_t k = 0; k < min_len; ++k) {
auto token = task.seq_tokens[0][k];
bool all_same = true;
for (size_t i = 1; i < task.seq_tokens.size(); ++i) {
if (task.seq_tokens[i][k] != token) {
all_same = false;
break;
}
}
if (!all_same) {
break;
}
++task.common_prefix;
}
task.required_tokens = task.common_prefix;
for (auto& seq : task.seq_tokens) {
task.required_tokens += seq.size() - task.common_prefix;
}
return true;
}
static void truthful_qa_score(llama_context * ctx, const gpt_params & params) { static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
// Calculates TruthFulQA score (multiple choice with single correct answer) from prompt // Calculates TruthFulQA score (multiple choice with single correct answer) from prompt
@ -1141,51 +1183,55 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
printf("%s: preparing task data", __func__); printf("%s: preparing task data", __func__);
fflush(stdout); fflush(stdout);
if (n_task > 500) {
printf("...");
fflush(stdout);
constexpr int k_chunk = 4;
std::atomic<int> counter(0);
std::atomic<int> n_bad(0);
auto prepare = [&counter, &n_bad, &tasks, ctx, add_bos] () {
int num_tasks = tasks.size();
int n_bad_local = 0;
while (true) {
int first = counter.fetch_add(k_chunk);
if (first >= num_tasks) {
if (n_bad_local > 0) n_bad += n_bad_local;
break;
}
int last = std::min(first + k_chunk, num_tasks);
for (int i = first; i < last; ++i) {
if (!truthful_qa_prepare_one_task(ctx, add_bos, tasks[i], false)) ++n_bad_local;
}
}
};
size_t max_thread = std::thread::hardware_concurrency();
max_thread = std::min(max_thread, (tasks.size() + k_chunk - 1)/k_chunk);
std::vector<std::thread> workers(max_thread-1);
for (auto& w : workers) w = std::thread(prepare);
prepare();
for (auto& w : workers) w.join();
printf("done\n");
fflush(stdout);
int nbad = n_bad;
if (nbad > 0) {
printf("%s: found %d malformed tasks\n", __func__, nbad);
return;
}
} else {
int n_dot = n_task/100; int n_dot = n_task/100;
int i_task = 0; int i_task = 0;
for (auto& task : tasks) { for (auto& task : tasks) {
++i_task; ++i_task;
if (task.question.empty() || task.mc1.answers.empty()) { if (!truthful_qa_prepare_one_task(ctx, add_bos, task, true)) {
printf("%s: found bad task with empty question and/or answers\n", __func__);
return; return;
} }
task.seq_tokens.reserve(task.mc1.answers.size());
for (auto& answer : task.mc1.answers) {
if (answer.empty()) {
printf("%s: found empty answer\n", __func__);
return;
}
task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, add_bos));
}
auto min_len = task.seq_tokens.front().size();
for (auto& seq : task.seq_tokens) {
min_len = std::min(min_len, seq.size());
}
task.common_prefix = 0;
for (size_t k = 0; k < min_len; ++k) {
auto token = task.seq_tokens[0][k];
bool all_same = true;
for (size_t i = 1; i < task.seq_tokens.size(); ++i) {
if (task.seq_tokens[i][k] != token) {
all_same = false;
break;
}
}
if (!all_same) {
break;
}
++task.common_prefix;
}
task.required_tokens = task.common_prefix;
for (auto& seq : task.seq_tokens) {
task.required_tokens += seq.size() - task.common_prefix;
}
if (i_task%n_dot == 0) { if (i_task%n_dot == 0) {
printf("."); printf(".");
fflush(stdout); fflush(stdout);
} }
} }
printf("done\n"); printf("done\n");
}
printf("%s : calculating TruthfulQA score over %zu tasks.\n", __func__, tasks.size()); printf("%s : calculating TruthfulQA score over %zu tasks.\n", __func__, tasks.size());