log distribution after prompt tokens

Maël Kerbiriou 2023-03-16 18:58:59 +01:00
parent 4547848743
commit 03755743cf
3 changed files with 26 additions and 0 deletions


@@ -1017,6 +1017,11 @@ int main(int argc, char ** argv) {
             // decrement remaining sampling budget
             --remaining_tokens;
         } else {
+            if(log_file) {
+                const int n_vocab = model.hparams.n_vocab;
+                const float temp = params.temp;
+                print_output(vocab, logits.data() + (logits.size() - n_vocab), temp);
+            }
             // some user input remains from prompt or interaction, forward it to processing
             while (embd_inp.size() > input_consumed) {
                 embd.push_back(embd_inp[input_consumed]);

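Note on the call site above: the logits buffer appears to hold one row of n_vocab values per evaluated token, so logits.data() + (logits.size() - n_vocab) points at the row for the most recent token, which is the distribution being logged. A minimal sketch of the pointer arithmetic, with hypothetical sizes and values rather than code from this branch:

    #include <vector>

    int main() {
        const int n_vocab = 4;                  // hypothetical vocab size
        std::vector<float> logits = {
            0.1f, 0.2f, 0.3f, 0.4f,             // row for token 0
            1.0f, 2.0f, 3.0f, 4.0f,             // row for token 1 (most recent)
        };
        // same pointer arithmetic as the call site above:
        const float * last_row = logits.data() + (logits.size() - n_vocab);
        // last_row[0..n_vocab) is {1.0f, 2.0f, 3.0f, 4.0f}
        return 0;
    }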

@@ -649,6 +649,21 @@ gpt_vocab::id sample_top_k_top_p(
     return sampled_tok_id;
 }
 
+gpt_vocab::id print_output(
+        const gpt_vocab & vocab,
+        const float * logits,
+        double temp) {
+    SoftMaxSampler probs;
+    probs.reset(vocab, logits, temp);
+    probs.top_k_sort();
+    probs.soft_max();
+    probs.print(log_file, vocab, logits, 16);
+    return probs.top();
+}
+
 size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t * hist) {
     const int nb = k / qk;

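The SoftMaxSampler used by print_output is defined elsewhere on this branch. As a reading aid only, here is a minimal sketch consistent with the five calls above (reset, top_k_sort, soft_max, print, top); the field names, the sorting and softmax details, and the trimmed gpt_vocab stand-in are assumptions, not the branch's actual implementation:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    // Trimmed stand-in for the gpt_vocab in utils.h (only the field used here).
    struct gpt_vocab {
        using id    = int32_t;
        using token = std::string;
        std::map<id, token> id_to_token;
    };

    // Sketch only: mirrors the call sequence in print_output, not the real class.
    struct SoftMaxSampler {
        // (scaled logit, later overwritten by probability) paired with token id
        std::vector<std::pair<float, gpt_vocab::id>> cand;

        // Load every token's logit, scaled by 1/temp.
        void reset(const gpt_vocab & vocab, const float * logits, double temp) {
            cand.clear();
            for (const auto & kv : vocab.id_to_token) {
                cand.emplace_back((float) (logits[kv.first] / temp), kv.first);
            }
        }

        // Sort descending so the most likely candidates come first.
        void top_k_sort() {
            std::sort(cand.begin(), cand.end(),
                [](const auto & a, const auto & b) { return a.first > b.first; });
        }

        // Replace scaled logits with softmax probabilities (numerically stable).
        void soft_max() {
            if (cand.empty()) return;
            const float max_l = cand[0].first; // list is sorted descending
            float sum = 0.0f;
            for (auto & c : cand) { c.first = std::exp(c.first - max_l); sum += c.first; }
            for (auto & c : cand) { c.first /= sum; }
        }

        // Log the n_print most probable tokens: id, raw logit, probability, text.
        void print(FILE * f, const gpt_vocab & vocab, const float * logits, int n_print) const {
            for (int i = 0; i < n_print && i < (int) cand.size(); ++i) {
                const gpt_vocab::id id = cand[i].second;
                fprintf(f, "%6d %8.3f %6.3f '%s'\n",
                    id, logits[id], cand[i].first, vocab.id_to_token.at(id).c_str());
            }
        }

        // Most probable token id (valid once the list has been sorted).
        gpt_vocab::id top() const { return cand.empty() ? 0 : cand[0].second; }
    };

Under this reading, the 16 passed by print_output would cap the log at the top 16 candidates per step, and log_file is assumed to be a FILE* opened earlier in main.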

@@ -106,6 +106,12 @@ gpt_vocab::id sample_top_k_top_p(
         double temp,
         std::mt19937 & rng);
 
+// Print would-be output after prompt samples
+gpt_vocab::id print_output(
+        const gpt_vocab & vocab,
+        const float * logits,
+        double temp);
+
 //
 // Quantization
 //