From 281e2bad8c0ab087dbf1f6f307f60e9173a371d7 Mon Sep 17 00:00:00 2001
From: kalomaze <66376113+kalomaze@users.noreply.github.com>
Date: Sun, 3 Dec 2023 05:25:44 -0600
Subject: [PATCH] Fix missing logit restoration step (?)

Does this matter, actually?
---
 common/sampling.cpp | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 78092611b..d87340d2b 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -131,9 +131,14 @@ llama_token llama_sampling_sample(
 
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
+
+    // Declare original_logits at the beginning of the function scope
+    std::vector<float> original_logits;
 
-    // Make a copy of the original logits before any modifications
-    std::vector<float> original_logits(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
+    if (!is_resampling) {
+        // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
+        original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
+    }
 
     // apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
@@ -233,6 +238,9 @@ llama_token llama_sampling_sample(
         if (!is_valid) {
             LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
 
+            // Restore logits from the copy
+            std::copy(original_logits.begin(), original_logits.end(), logits);
+
             // Recursively call llama_sampling_sample to resample with the grammar checks applied first
             return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
         }