Fix missing logit restoration step (?)
Does this matter, actually?
This commit is contained in:
parent
2e3b4f6237
commit
281e2bad8c
1 changed files with 10 additions and 2 deletions
|
@ -131,9 +131,14 @@ llama_token llama_sampling_sample(
|
||||||
|
|
||||||
// Get a pointer to the logits
|
// Get a pointer to the logits
|
||||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||||
|
|
||||||
|
// Declare original_logits at the beginning of the function scope
|
||||||
|
std::vector<float> original_logits;
|
||||||
|
|
||||||
// Make a copy of the original logits before any modifications
|
if (!is_resampling) {
|
||||||
std::vector<float> original_logits(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
|
// Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
|
||||||
|
original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
|
||||||
|
}
|
||||||
|
|
||||||
// apply params.logit_bias map
|
// apply params.logit_bias map
|
||||||
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
||||||
|
@ -233,6 +238,9 @@ llama_token llama_sampling_sample(
|
||||||
if (!is_valid) {
|
if (!is_valid) {
|
||||||
LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
|
LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
|
||||||
|
|
||||||
|
// Restore logits from the copy
|
||||||
|
std::copy(original_logits.begin(), original_logits.end(), logits);
|
||||||
|
|
||||||
// Recursively call llama_sampling_sample to resample with the grammar checks applied first
|
// Recursively call llama_sampling_sample to resample with the grammar checks applied first
|
||||||
return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
|
return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue