TruthfulQA: prepare tasks in parallel for large test datasets
This commit is contained in:
parent
21d0ce5e05
commit
d86f80f416
1 changed files with 85 additions and 39 deletions
|
@ -1074,6 +1074,48 @@ struct truthful_qa_task {
|
|||
std::vector<float> log_probs;
|
||||
};
|
||||
|
||||
static bool truthful_qa_prepare_one_task(llama_context * ctx, bool add_bos, truthful_qa_task& task, bool log_error) {
|
||||
if (task.question.empty() || task.mc1.answers.empty()) {
|
||||
if (log_error) {
|
||||
printf("%s: found bad task with empty question and/or answers\n", __func__);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
task.seq_tokens.reserve(task.mc1.answers.size());
|
||||
for (auto& answer : task.mc1.answers) {
|
||||
if (answer.empty()) {
|
||||
if (log_error) {
|
||||
printf("%s: found empty answer\n", __func__);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, add_bos));
|
||||
}
|
||||
auto min_len = task.seq_tokens.front().size();
|
||||
for (auto& seq : task.seq_tokens) {
|
||||
min_len = std::min(min_len, seq.size());
|
||||
}
|
||||
task.common_prefix = 0;
|
||||
for (size_t k = 0; k < min_len; ++k) {
|
||||
auto token = task.seq_tokens[0][k];
|
||||
bool all_same = true;
|
||||
for (size_t i = 1; i < task.seq_tokens.size(); ++i) {
|
||||
if (task.seq_tokens[i][k] != token) {
|
||||
all_same = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!all_same) {
|
||||
break;
|
||||
}
|
||||
++task.common_prefix;
|
||||
}
|
||||
task.required_tokens = task.common_prefix;
|
||||
for (auto& seq : task.seq_tokens) {
|
||||
task.required_tokens += seq.size() - task.common_prefix;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
|
||||
// Calculates TruthFulQA score (multiple choice with single correct answer) from prompt
|
||||
|
@ -1141,51 +1183,55 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
|
|||
|
||||
printf("%s: preparing task data", __func__);
|
||||
fflush(stdout);
|
||||
if (n_task > 500) {
|
||||
printf("...");
|
||||
fflush(stdout);
|
||||
constexpr int k_chunk = 4;
|
||||
std::atomic<int> counter(0);
|
||||
std::atomic<int> n_bad(0);
|
||||
auto prepare = [&counter, &n_bad, &tasks, ctx, add_bos] () {
|
||||
int num_tasks = tasks.size();
|
||||
int n_bad_local = 0;
|
||||
while (true) {
|
||||
int first = counter.fetch_add(k_chunk);
|
||||
if (first >= num_tasks) {
|
||||
if (n_bad_local > 0) n_bad += n_bad_local;
|
||||
break;
|
||||
}
|
||||
int last = std::min(first + k_chunk, num_tasks);
|
||||
for (int i = first; i < last; ++i) {
|
||||
if (!truthful_qa_prepare_one_task(ctx, add_bos, tasks[i], false)) ++n_bad_local;
|
||||
}
|
||||
}
|
||||
};
|
||||
size_t max_thread = std::thread::hardware_concurrency();
|
||||
max_thread = std::min(max_thread, (tasks.size() + k_chunk - 1)/k_chunk);
|
||||
std::vector<std::thread> workers(max_thread-1);
|
||||
for (auto& w : workers) w = std::thread(prepare);
|
||||
prepare();
|
||||
for (auto& w : workers) w.join();
|
||||
printf("done\n");
|
||||
fflush(stdout);
|
||||
int nbad = n_bad;
|
||||
if (nbad > 0) {
|
||||
printf("%s: found %d malformed tasks\n", __func__, nbad);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
int n_dot = n_task/100;
|
||||
int i_task = 0;
|
||||
for (auto& task : tasks) {
|
||||
++i_task;
|
||||
if (task.question.empty() || task.mc1.answers.empty()) {
|
||||
printf("%s: found bad task with empty question and/or answers\n", __func__);
|
||||
if (!truthful_qa_prepare_one_task(ctx, add_bos, task, true)) {
|
||||
return;
|
||||
}
|
||||
task.seq_tokens.reserve(task.mc1.answers.size());
|
||||
for (auto& answer : task.mc1.answers) {
|
||||
if (answer.empty()) {
|
||||
printf("%s: found empty answer\n", __func__);
|
||||
return;
|
||||
}
|
||||
task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, add_bos));
|
||||
}
|
||||
auto min_len = task.seq_tokens.front().size();
|
||||
for (auto& seq : task.seq_tokens) {
|
||||
min_len = std::min(min_len, seq.size());
|
||||
}
|
||||
task.common_prefix = 0;
|
||||
for (size_t k = 0; k < min_len; ++k) {
|
||||
auto token = task.seq_tokens[0][k];
|
||||
bool all_same = true;
|
||||
for (size_t i = 1; i < task.seq_tokens.size(); ++i) {
|
||||
if (task.seq_tokens[i][k] != token) {
|
||||
all_same = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!all_same) {
|
||||
break;
|
||||
}
|
||||
++task.common_prefix;
|
||||
}
|
||||
task.required_tokens = task.common_prefix;
|
||||
for (auto& seq : task.seq_tokens) {
|
||||
task.required_tokens += seq.size() - task.common_prefix;
|
||||
}
|
||||
if (i_task%n_dot == 0) {
|
||||
printf(".");
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
printf("done\n");
|
||||
}
|
||||
|
||||
printf("%s : calculating TruthfulQA score over %zu tasks.\n", __func__, tasks.size());
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue