TruthfulQA: fix random sample
This commit is contained in:
parent
b0a4873697
commit
21d0ce5e05
1 changed files with 1 additions and 15 deletions
|
@ -1119,6 +1119,7 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
|
||||||
printf("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.thruthful_qa_tasks, n_task);
|
printf("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.thruthful_qa_tasks, n_task);
|
||||||
std::mt19937 rng(1);
|
std::mt19937 rng(1);
|
||||||
std::vector<int> aux(n_task);
|
std::vector<int> aux(n_task);
|
||||||
|
for (uint32_t i = 0; i < n_task; ++i) aux[i] = i;
|
||||||
float scale = 1.f/(1.f + (float)std::mt19937::max());
|
float scale = 1.f/(1.f + (float)std::mt19937::max());
|
||||||
tasks.resize(params.thruthful_qa_tasks);
|
tasks.resize(params.thruthful_qa_tasks);
|
||||||
for (auto& task : tasks) {
|
for (auto& task : tasks) {
|
||||||
|
@ -1310,26 +1311,11 @@ static void truthful_qa_score(llama_context * ctx, const gpt_params & params) {
|
||||||
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
|
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
|
||||||
size_t count = 1;
|
size_t count = 1;
|
||||||
float log_prob = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
|
float log_prob = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
|
||||||
//printf(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob);
|
|
||||||
//for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
|
|
||||||
// printf(" %zu %g\n", ir, eval_results[ir]);
|
|
||||||
// ++count;
|
|
||||||
// log_prob += eval_results[ir++];
|
|
||||||
//}
|
|
||||||
//size_t count = 0;
|
|
||||||
//float log_prob = 0;
|
|
||||||
//printf(" <%s>\n", cur_task.mc1.answers[s].c_str());
|
|
||||||
//float log_prob = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
|
|
||||||
//printf(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob);
|
|
||||||
for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
|
for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
|
||||||
//printf(" %zu %g\n", ir, eval_results[ir]);
|
//printf(" %zu %g\n", ir, eval_results[ir]);
|
||||||
++count;
|
++count;
|
||||||
log_prob += eval_results[ir++];
|
log_prob += eval_results[ir++];
|
||||||
}
|
}
|
||||||
//if (!count) {
|
|
||||||
// ++count;
|
|
||||||
// log_prob += std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
|
|
||||||
//}
|
|
||||||
cur_task.log_probs[s] = log_prob / count;
|
cur_task.log_probs[s] = log_prob / count;
|
||||||
//printf(" Final: %g\n", log_prob / count);
|
//printf(" Final: %g\n", log_prob / count);
|
||||||
//printf(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob/count);
|
//printf(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob/count);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue