From fe5e95bb02e08b8f6eeef90479949b98e23f329b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 22 Jan 2024 06:48:56 +0200 Subject: [PATCH] kl-divergence: be able to save all logits to a file --- common/common.cpp | 6 ++ common/common.h | 1 + examples/perplexity/perplexity.cpp | 98 +++++++++++++++++++++++++++++- 3 files changed, 103 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 0e4b8bab2..f2f9c074d 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -672,6 +672,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { if (params.logdir.back() != DIRECTORY_SEPARATOR) { params.logdir += DIRECTORY_SEPARATOR; } + } else if (arg == "--save-all-logits") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.logits_file = argv[i]; } else if (arg == "--perplexity" || arg == "--all-logits") { params.logits_all = true; } else if (arg == "--ppl-stride") { diff --git a/common/common.h b/common/common.h index c69ad7e94..2128e360a 100644 --- a/common/common.h +++ b/common/common.h @@ -91,6 +91,7 @@ struct gpt_params { std::string input_suffix = ""; // string to suffix user inputs with std::vector antiprompt; // string upon seeing which more user input is prompted std::string logdir = ""; // directory in which to save YAML log files + std::string logits_file = ""; // file for saving *all* logits std::vector kv_overrides; diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index b7ef9a084..224827058 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -112,6 +112,43 @@ static results_log_softmax log_softmax(int n_vocab, const float * logits, int to return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp}; } +static inline int nearest_int(float fval) { + //assert(fval <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} + +static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob, int tok) { + float max_logit = logits[0]; + float min_logit = logits[0]; + for (int i = 1; i < n_vocab; ++i) { + max_logit = std::max(max_logit, logits[i]); + min_logit = std::min(min_logit, logits[i]); + } + min_logit = std::max(min_logit, max_logit - 16); + double sum_exp = 0.0; + for (int i = 0; i < n_vocab; ++i) { + sum_exp += expf(logits[i] - max_logit); + } + const float log_sum_exp = log(sum_exp); + const float min_log_prob = min_logit - max_logit - log_sum_exp; + const float scale = (max_logit - min_logit)/65535.f; + float * d = (float *)log_prob; + d[0] = scale; + d[1] = min_log_prob; + log_prob += 4; + if (scale) { + const float inv_scale = 1/scale; + for (int i = 0; i < n_vocab; ++i) { + log_prob[i] = logits[i] > min_logit ? nearest_int(inv_scale*(logits[i] - min_logit)) : 0; + } + } else { + std::memset(log_prob, 0, n_vocab*sizeof(uint16_t)); + } + return max_logit + log_sum_exp - logits[tok]; +} + static void process_logits( int n_vocab, const float * logits, const int * tokens, int n_token, std::vector & workers, double & nll, double & nll2, float * logit_history, float * prob_history @@ -147,6 +184,37 @@ static void process_logits( } } +static void process_logits(std::ostream& out, int n_vocab, const float * logits, const int * tokens, int n_token, + std::vector & workers, std::vector & log_probs, double & nll, double & nll2) { + std::mutex mutex; + const int nv = 2*((n_vocab + 1)/2) + 4; + int counter = 0; + auto compute = [&mutex, &counter, &log_probs, &nll, &nll2, n_vocab, logits, tokens, n_token, nv] () { + double local_nll = 0; + double local_nll2 = 0; + while (true) { + std::unique_lock lock(mutex); + int i = counter++; + if (i >= n_token) { + nll += local_nll; nll2 += local_nll2; + break; + } + lock.unlock(); + const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]); + local_nll += v; + local_nll2 += v*v; + } + }; + for (auto & w : workers) { + w = std::thread(compute); + } + compute(); + for (auto & w : workers) { + w.join(); + } + out.write((const char *)log_probs.data(), n_token*nv*sizeof(uint16_t)); +} + static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) { // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` @@ -294,6 +362,18 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); const int n_ctx = llama_n_ctx(ctx); + std::ofstream logits_stream; + if (!params.logits_file.empty()) { + logits_stream.open(params.logits_file.c_str()); + if (!logits_stream.is_open()) { + fprintf(stderr, "%s: failed to open %s for writing\n", __func__, params.logits_file.c_str()); + return {}; + } + fprintf(stderr, "%s: saving all logits to %s\n", __func__, params.logits_file.c_str()); + logits_stream.write("_logits_", 8); + logits_stream.write((const char *)&n_ctx, sizeof(n_ctx)); + } + auto tim1 = std::chrono::high_resolution_clock::now(); fprintf(stderr, "%s: tokenizing the input ..\n", __func__); @@ -336,6 +416,15 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par std::vector workers(std::thread::hardware_concurrency() - 1); + std::vector log_probs; + if (!params.logits_file.empty()) { + logits_stream.write((const char *)&n_vocab, sizeof(n_vocab)); + logits_stream.write((const char *)&n_chunk, sizeof(n_chunk)); + logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0])); + const int nv = 2*((n_vocab + 1)/2) + 4; + log_probs.resize(n_ctx * nv); + } + for (int i = 0; i < n_chunk; ++i) { const int start = i * n_ctx; const int end = start + n_ctx; @@ -398,8 +487,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // process the entire prompt. const int first = n_ctx/2; const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx); - process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, - workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first); + if (!params.logits_file.empty()) { + process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, + workers, log_probs, nll, nll2); + } else { + process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, + workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first); + } count += n_ctx - first - 1; // perplexity is e^(average negative log-likelihood)