diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index c23eaefab..e32f86130 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -1075,5 +1075,58 @@ struct truthful_qa_task {
 };
 
+// Tokenizes each answer of a task together with its question and fills in the
+// data derived from those token sequences:
+//   task.seq_tokens      - one sequence per answer, built from question + " " + answer
+//   task.common_prefix   - number of leading tokens identical across all sequences
+//   task.required_tokens - the common prefix counted once plus every sequence's remainder
+// Returns false when the question, the answer list, or any single answer is empty;
+// a message is printed only if log_error is set. A task that fails part-way may be
+// left with some sequences already appended.
+// NOTE(review): callers may run this from several threads with the same ctx, which
+// presumes ::llama_tokenize tolerates concurrent use - confirm against llama.h.
+static bool truthful_qa_prepare_one_task(llama_context * ctx, bool add_bos, truthful_qa_task& task, bool log_error) {
+    if (task.question.empty() || task.mc1.answers.empty()) {
+        if (log_error) {
+            printf("%s: found bad task with empty question and/or answers\n", __func__);
+        }
+        return false;
+    }
+    task.seq_tokens.reserve(task.mc1.answers.size());
+    for (auto& answer : task.mc1.answers) {
+        if (answer.empty()) {
+            if (log_error) {
+                printf("%s: found empty answer\n", __func__);
+            }
+            return false;
+        }
+        task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, add_bos));
+    }
+    // The common prefix cannot be longer than the shortest sequence.
+    auto min_len = task.seq_tokens.front().size();
+    for (auto& seq : task.seq_tokens) {
+        min_len = std::min(min_len, seq.size());
+    }
+    task.common_prefix = 0;
+    for (size_t k = 0; k < min_len; ++k) {
+        auto token = task.seq_tokens[0][k];
+        bool all_same = true;
+        for (size_t i = 1; i < task.seq_tokens.size(); ++i) {
+            if (task.seq_tokens[i][k] != token) {
+                all_same = false;
+                break;
+            }
+        }
+        if (!all_same) {
+            break;
+        }
+        ++task.common_prefix;
+    }
+    task.required_tokens = task.common_prefix;
+    for (auto& seq : task.seq_tokens) {
+        task.required_tokens += seq.size() - task.common_prefix;
+    }
+    return true;
+}
 
 static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
     // Calculates TruthFulQA score (multiple choice with single correct answer) from prompt
@@ -1141,51 +1194,63 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
 
     printf("%s: preparing task data", __func__);
     fflush(stdout);
-    int n_dot = n_task/100;
-    int i_task = 0;
-    for (auto& task : tasks) {
-        ++i_task;
-        if (task.question.empty() || task.mc1.answers.empty()) {
-            printf("%s: found bad task with empty question and/or answers\n", __func__);
-            return;
-        }
-        task.seq_tokens.reserve(task.mc1.answers.size());
-        for (auto& answer : task.mc1.answers) {
-            if (answer.empty()) {
-                printf("%s: found empty answer\n", __func__);
-                return;
-            }
-            task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, add_bos));
-        }
-        auto min_len = task.seq_tokens.front().size();
-        for (auto& seq : task.seq_tokens) {
-            min_len = std::min(min_len, seq.size());
-        }
-        task.common_prefix = 0;
-        for (size_t k = 0; k < min_len; ++k) {
-            auto token = task.seq_tokens[0][k];
-            bool all_same = true;
-            for (size_t i = 1; i < task.seq_tokens.size(); ++i) {
-                if (task.seq_tokens[i][k] != token) {
-                    all_same = false;
+    if (n_task > 500) {
+        // Large data sets: tokenize on several threads. Each thread claims chunks of
+        // k_chunk consecutive tasks from a shared atomic counter until none are left.
+        printf("...");
+        fflush(stdout);
+        constexpr int k_chunk = 4;
+        std::atomic<int> counter(0);
+        std::atomic<int> n_bad(0);
+        auto prepare = [&counter, &n_bad, &tasks, ctx, add_bos] () {
+            int num_tasks = tasks.size();
+            int n_bad_local = 0;
+            while (true) {
+                int first = counter.fetch_add(k_chunk);
+                if (first >= num_tasks) {
+                    // publish this thread's failure count once, right before it exits
+                    if (n_bad_local > 0) n_bad += n_bad_local;
                     break;
                 }
+                int last = std::min(first + k_chunk, num_tasks);
+                for (int i = first; i < last; ++i) {
+                    if (!truthful_qa_prepare_one_task(ctx, add_bos, tasks[i], false)) ++n_bad_local;
+                }
             }
-            if (!all_same) {
-                break;
+        };
+        size_t max_thread = std::thread::hardware_concurrency();
+        max_thread = std::min(max_thread, (tasks.size() + k_chunk - 1)/k_chunk);
+        // hardware_concurrency() may return 0; always keep the calling thread so that
+        // max_thread-1 below cannot wrap around to a huge worker count.
+        if (max_thread < 1) max_thread = 1;
+        std::vector<std::thread> workers(max_thread-1);
+        for (auto& w : workers) w = std::thread(prepare);
+        prepare();
+        for (auto& w : workers) w.join();
+        printf("done\n");
+        fflush(stdout);
+        int nbad = n_bad;
+        if (nbad > 0) {
+            printf("%s: found %d malformed tasks\n", __func__, nbad);
+            return;
+        }
+    } else {
+        int n_dot = n_task/100;
+        // fewer than 100 tasks would give n_dot == 0 and a modulo by zero below
+        if (n_dot < 1) n_dot = 1;
+        int i_task = 0;
+        for (auto& task : tasks) {
+            ++i_task;
+            if (!truthful_qa_prepare_one_task(ctx, add_bos, task, true)) {
+                return;
+            }
+            if (i_task%n_dot == 0) {
+                printf(".");
+                fflush(stdout);
             }
-            ++task.common_prefix;
-        }
-        task.required_tokens = task.common_prefix;
-        for (auto& seq : task.seq_tokens) {
-            task.required_tokens += seq.size() - task.common_prefix;
-        }
-        if (i_task%n_dot == 0) {
-            printf(".");
-            fflush(stdout);
         }
+        printf("done\n");
     }
-    printf("done\n");
 
     printf("%s : calculating TruthfulQA score over %zu tasks.\n", __func__, tasks.size());
 