Fix missing logit restoration step (?)

Does this matter, actually?
This commit is contained in:
kalomaze 2023-12-03 05:25:44 -06:00
parent 2e3b4f6237
commit 281e2bad8c

View file

@@ -131,9 +131,14 @@ llama_token llama_sampling_sample(
// Get a pointer to the logits // Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx); float * logits = llama_get_logits_ith(ctx_main, idx);
// Declare original_logits at the beginning of the function scope
std::vector<float> original_logits;
// Make a copy of the original logits before any modifications if (!is_resampling) {
std::vector<float> original_logits(logits, logits + llama_n_vocab(llama_get_model(ctx_main))); // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
}
// apply params.logit_bias map // apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
@@ -233,6 +238,9 @@ llama_token llama_sampling_sample(
if (!is_valid) { if (!is_valid) {
LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str()); LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
// Restore logits from the copy
std::copy(original_logits.begin(), original_logits.end(), logits);
// Recursively call llama_sampling_sample to resample with the grammar checks applied first // Recursively call llama_sampling_sample to resample with the grammar checks applied first
return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
} }