TruthfulQA: prepare tasks in parallel for large test datasets
This commit is contained in:
parent
21d0ce5e05
commit
d86f80f416
1 changed files with 85 additions and 39 deletions
|
@ -1074,6 +1074,48 @@ struct truthful_qa_task {
|
||||||
std::vector<float> log_probs;
|
std::vector<float> log_probs;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static bool truthful_qa_prepare_one_task(llama_context * ctx, bool add_bos, truthful_qa_task& task, bool log_error) {
|
||||||
|
if (task.question.empty() || task.mc1.answers.empty()) {
|
||||||
|
if (log_error) {
|
||||||
|
printf("%s: found bad task with empty question and/or answers\n", __func__);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
task.seq_tokens.reserve(task.mc1.answers.size());
|
||||||
|
for (auto& answer : task.mc1.answers) {
|
||||||
|
if (answer.empty()) {
|
||||||
|
if (log_error) {
|
||||||
|
printf("%s: found empty answer\n", __func__);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, add_bos));
|
||||||
|
}
|
||||||
|
auto min_len = task.seq_tokens.front().size();
|
||||||
|
for (auto& seq : task.seq_tokens) {
|
||||||
|
min_len = std::min(min_len, seq.size());
|
||||||
|
}
|
||||||
|
task.common_prefix = 0;
|
||||||
|
for (size_t k = 0; k < min_len; ++k) {
|
||||||
|
auto token = task.seq_tokens[0][k];
|
||||||
|
bool all_same = true;
|
||||||
|
for (size_t i = 1; i < task.seq_tokens.size(); ++i) {
|
||||||
|
if (task.seq_tokens[i][k] != token) {
|
||||||
|
all_same = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!all_same) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++task.common_prefix;
|
||||||
|
}
|
||||||
|
task.required_tokens = task.common_prefix;
|
||||||
|
for (auto& seq : task.seq_tokens) {
|
||||||
|
task.required_tokens += seq.size() - task.common_prefix;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
|
static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
|
||||||
// Calculates TruthFulQA score (multiple choice with single correct answer) from prompt
|
// Calculates TruthFulQA score (multiple choice with single correct answer) from prompt
|
||||||
|
@ -1141,51 +1183,55 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
|
||||||
|
|
||||||
printf("%s: preparing task data", __func__);
|
printf("%s: preparing task data", __func__);
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
if (n_task > 500) {
|
||||||
|
printf("...");
|
||||||
|
fflush(stdout);
|
||||||
|
constexpr int k_chunk = 4;
|
||||||
|
std::atomic<int> counter(0);
|
||||||
|
std::atomic<int> n_bad(0);
|
||||||
|
auto prepare = [&counter, &n_bad, &tasks, ctx, add_bos] () {
|
||||||
|
int num_tasks = tasks.size();
|
||||||
|
int n_bad_local = 0;
|
||||||
|
while (true) {
|
||||||
|
int first = counter.fetch_add(k_chunk);
|
||||||
|
if (first >= num_tasks) {
|
||||||
|
if (n_bad_local > 0) n_bad += n_bad_local;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
int last = std::min(first + k_chunk, num_tasks);
|
||||||
|
for (int i = first; i < last; ++i) {
|
||||||
|
if (!truthful_qa_prepare_one_task(ctx, add_bos, tasks[i], false)) ++n_bad_local;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
size_t max_thread = std::thread::hardware_concurrency();
|
||||||
|
max_thread = std::min(max_thread, (tasks.size() + k_chunk - 1)/k_chunk);
|
||||||
|
std::vector<std::thread> workers(max_thread-1);
|
||||||
|
for (auto& w : workers) w = std::thread(prepare);
|
||||||
|
prepare();
|
||||||
|
for (auto& w : workers) w.join();
|
||||||
|
printf("done\n");
|
||||||
|
fflush(stdout);
|
||||||
|
int nbad = n_bad;
|
||||||
|
if (nbad > 0) {
|
||||||
|
printf("%s: found %d malformed tasks\n", __func__, nbad);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
int n_dot = n_task/100;
|
int n_dot = n_task/100;
|
||||||
int i_task = 0;
|
int i_task = 0;
|
||||||
for (auto& task : tasks) {
|
for (auto& task : tasks) {
|
||||||
++i_task;
|
++i_task;
|
||||||
if (task.question.empty() || task.mc1.answers.empty()) {
|
if (!truthful_qa_prepare_one_task(ctx, add_bos, task, true)) {
|
||||||
printf("%s: found bad task with empty question and/or answers\n", __func__);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
task.seq_tokens.reserve(task.mc1.answers.size());
|
|
||||||
for (auto& answer : task.mc1.answers) {
|
|
||||||
if (answer.empty()) {
|
|
||||||
printf("%s: found empty answer\n", __func__);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, add_bos));
|
|
||||||
}
|
|
||||||
auto min_len = task.seq_tokens.front().size();
|
|
||||||
for (auto& seq : task.seq_tokens) {
|
|
||||||
min_len = std::min(min_len, seq.size());
|
|
||||||
}
|
|
||||||
task.common_prefix = 0;
|
|
||||||
for (size_t k = 0; k < min_len; ++k) {
|
|
||||||
auto token = task.seq_tokens[0][k];
|
|
||||||
bool all_same = true;
|
|
||||||
for (size_t i = 1; i < task.seq_tokens.size(); ++i) {
|
|
||||||
if (task.seq_tokens[i][k] != token) {
|
|
||||||
all_same = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!all_same) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
++task.common_prefix;
|
|
||||||
}
|
|
||||||
task.required_tokens = task.common_prefix;
|
|
||||||
for (auto& seq : task.seq_tokens) {
|
|
||||||
task.required_tokens += seq.size() - task.common_prefix;
|
|
||||||
}
|
|
||||||
if (i_task%n_dot == 0) {
|
if (i_task%n_dot == 0) {
|
||||||
printf(".");
|
printf(".");
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
printf("done\n");
|
printf("done\n");
|
||||||
|
}
|
||||||
|
|
||||||
printf("%s : calculating TruthfulQA score over %zu tasks.\n", __func__, tasks.size());
|
printf("%s : calculating TruthfulQA score over %zu tasks.\n", __func__, tasks.size());
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue