sampler log function

This commit is contained in:
Maël Kerbiriou 2023-03-16 18:12:17 +01:00
parent abbf7e7a61
commit aa6c2bd5d2

View file

@ -591,6 +591,29 @@ struct SoftMaxSampler {
) const { ) const {
return logits_id[dist(rng)].second; return logits_id[dist(rng)].second;
} }
void print(FILE* log_file, const gpt_vocab & vocab, const float * logits, int max_print, int selected=-1) const {
if (log_file == nullptr) {
return;
}
int n = probs.size();
if (n > max_print) {
n = max_print;
}
for (int i = 0; i < n; i++) {
const auto& entry = logits_id[i];
const int id = entry.second;
const double scaled_logit = entry.first;
fprintf(log_file, "%s%d: '%s' p=%f act=%.3f temp=%.2f\n",
selected >= 0 && id == selected ? "->" : " ",
i,
vocab.id_to_token.at(id).c_str(),
probs[i],
logits[id],
logits[id] / scaled_logit
);
}
}
}; };
gpt_vocab::id sample_top_k_top_p( gpt_vocab::id sample_top_k_top_p(