Fix missing logit restoration step (?)
Does this matter, actually?
This commit is contained in:
parent
2e3b4f6237
commit
281e2bad8c
1 changed file with 10 additions and 2 deletions
|
@@ -131,9 +131,14 @@ llama_token llama_sampling_sample(
|
|||
|
||||
// Get a pointer to the logits
|
||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||
|
||||
// Declare original_logits at the beginning of the function scope
|
||||
std::vector<float> original_logits;
|
||||
|
||||
// Make a copy of the original logits before any modifications
|
||||
std::vector<float> original_logits(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
|
||||
if (!is_resampling) {
|
||||
// Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
|
||||
original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
|
||||
}
|
||||
|
||||
// apply params.logit_bias map
|
||||
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
||||
|
@@ -233,6 +238,9 @@ llama_token llama_sampling_sample(
|
|||
if (!is_valid) {
|
||||
LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
|
||||
|
||||
// Restore logits from the copy
|
||||
std::copy(original_logits.begin(), original_logits.end(), logits);
|
||||
|
||||
// Recursively call llama_sampling_sample to resample with the grammar checks applied first
|
||||
return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue