From 03755743cf3eff5b9bc448a7b7787c889ed7af97 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ma=C3=ABl=20Kerbiriou?=
Date: Thu, 16 Mar 2023 18:58:59 +0100
Subject: [PATCH] log distribution after prompt tokens

---
 main.cpp  |  5 +++++
 utils.cpp | 15 +++++++++++++++
 utils.h   |  6 ++++++
 3 files changed, 26 insertions(+)

diff --git a/main.cpp b/main.cpp
index 4d6e91826..d11d0b11a 100644
--- a/main.cpp
+++ b/main.cpp
@@ -1017,6 +1017,11 @@ int main(int argc, char ** argv) {
             // decrement remaining sampling budget
             --remaining_tokens;
         } else {
+            if (log_file) {
+                const int n_vocab = model.hparams.n_vocab;
+                const float temp = params.temp;
+                print_output(vocab, logits.data() + (logits.size() - n_vocab), temp);
+            }
             // some user input remains from prompt or interaction, forward it to processing
             while (embd_inp.size() > input_consumed) {
                 embd.push_back(embd_inp[input_consumed]);
diff --git a/utils.cpp b/utils.cpp
index f1a73aa74..c83f2aaa1 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -649,6 +649,21 @@ gpt_vocab::id sample_top_k_top_p(
 
     return sampled_tok_id;
 }
 
+gpt_vocab::id print_output(
+        const gpt_vocab & vocab,
+        const float * logits,
+        double temp) {
+    SoftMaxSampler probs;
+    probs.reset(vocab, logits, temp);
+    probs.top_k_sort();
+    probs.soft_max();
+
+    probs.print(log_file, vocab, logits, 16);
+
+    return probs.top();
+}
+
+
 size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t * hist) {
     const int nb = k / qk;
diff --git a/utils.h b/utils.h
index 11609f801..2d7754b82 100644
--- a/utils.h
+++ b/utils.h
@@ -106,6 +106,12 @@ gpt_vocab::id sample_top_k_top_p(
         double temp,
         std::mt19937 & rng);
 
+// Print would-be output after prompt samples
+gpt_vocab::id print_output(
+        const gpt_vocab & vocab,
+        const float * logits,
+        double temp);
+
 //
 // Quantization
 //
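
Note on the logging path: the patch leans on the fork's SoftMaxSampler (reset, top_k_sort, soft_max, print, top), whose internals are not part of this diff. Below is a minimal standalone sketch of what that sequence presumably computes, assuming the usual conventions (logits scaled by 1/temp, distribution normalized over the kept top candidates); print_distribution and id_to_token are hypothetical names, not the fork's API.

// Hypothetical sketch, not the fork's SoftMaxSampler: temperature-scale the
// logits, keep the top_n candidates, softmax them, and print the distribution.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Print the top_n most likely tokens for one position's logits.
// Assumes temp > 0 and id_to_token maps token ids to their text.
void print_distribution(FILE * out,
                        const std::vector<std::string> & id_to_token,
                        const float * logits,
                        int n_vocab,
                        double temp,
                        int top_n) {
    // Pair each token id with its temperature-scaled logit.
    std::vector<std::pair<double, int>> cand(n_vocab);
    for (int i = 0; i < n_vocab; ++i) {
        cand[i] = { logits[i] / temp, i };
    }

    // Keep only the top_n highest-scoring candidates, sorted descending.
    top_n = std::min(top_n, n_vocab);
    std::partial_sort(cand.begin(), cand.begin() + top_n, cand.end(),
                      std::greater<>());

    // Softmax over the kept candidates (subtract the max for stability).
    const double max_l = cand[0].first;
    double sum = 0.0;
    std::vector<double> p(top_n);
    for (int i = 0; i < top_n; ++i) {
        p[i] = std::exp(cand[i].first - max_l);
        sum += p[i];
    }

    // One line per candidate: probability and token text.
    for (int i = 0; i < top_n; ++i) {
        fprintf(out, "%8.4f%% '%s'\n", 100.0 * p[i] / sum,
                id_to_token[cand[i].second].c_str());
    }
}

In this reading, probs.print(log_file, vocab, logits, 16) in the patch corresponds to top_n = 16, and probs.top() would be cand[0].second, the argmax token that generation would emit next.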