fix original_logits allocation

Parent: fddd201942
Commit: 27f2e85520

2 changed files with 9 additions and 12 deletions
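In brief: the sampling code previously heap-allocated the saved logits (a std::vector<float>* filled via new and released with a manual delete) and threaded a double pointer through the call chain so the callee could write the allocation back. The commit switches to a stack-allocated std::vector<float> owned by the sampling function, passed down as a plain std::vector<float>* and filled in place, dropping one level of indirection from every signature (in both the implementation file and its header) and removing the manual delete. A minimal sketch of the two ownership patterns, under hypothetical names (fill_logits_old / fill_logits_new are illustrations, not functions from the repository):

#include <cassert>
#include <vector>

// Old pattern: the callee heap-allocates and writes the pointer back;
// the caller becomes responsible for delete.
static void fill_logits_old(std::vector<float> ** out) {
    if (out != nullptr) {
        *out = new std::vector<float>{0.1f, 0.2f, 0.3f};
    }
}

// New pattern: the caller owns a stack vector; the callee only fills it.
static void fill_logits_new(std::vector<float> * out) {
    if (out != nullptr) {
        const float src[3] = {0.1f, 0.2f, 0.3f};
        out->insert(out->end(), src, src + 3); // mirrors the commit's insert()
    }
}

int main() {
    std::vector<float> * heap_logits = nullptr;
    fill_logits_old(&heap_logits);   // note the double indirection: &ptr
    assert(heap_logits != nullptr);
    delete heap_logits;              // manual cleanup, easy to miss

    std::vector<float> stack_logits = {};
    fill_logits_new(&stack_logits);  // single indirection: &vec
    assert(!stack_logits.empty());
    return 0;                        // stack_logits frees itself here
}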
@@ -173,10 +173,10 @@ static llama_token llama_sampling_sample_impl(
     const float mirostat_tau = params.mirostat_tau;
     const float mirostat_eta = params.mirostat_eta;
 
-    std::vector<float>* original_logits = nullptr;
+    std::vector<float> original_logits = {};
     auto cur_p = llama_sampling_configure_token_candidates(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
     if (!is_resampling) {
-        GGML_ASSERT(original_logits != nullptr);
+        GGML_ASSERT(!original_logits.empty());
     }
     llama_token id = 0;
     // Get a pointer to the logits
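Note that the assertion gets strictly stronger here: the old GGML_ASSERT only verified that an allocation had happened, while the new one verifies that the callee actually wrote logits into the vector. A small sketch of the distinction, with plain assert standing in for GGML_ASSERT:

#include <cassert>
#include <vector>

int main() {
    // Old check: "did the callee allocate something?" -- passes even if
    // the allocated vector holds no logits at all.
    std::vector<float> * heap_logits = new std::vector<float>();
    assert(heap_logits != nullptr);
    delete heap_logits;

    // New check: "did the callee actually copy logits?" -- fails loudly
    // if the copy step was skipped.
    std::vector<float> stack_logits = {};
    stack_logits.push_back(0.5f);    // the callee's insert() would do this
    assert(!stack_logits.empty());
    return 0;
}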
@@ -236,14 +236,11 @@ static llama_token llama_sampling_sample_impl(
             LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
 
             // Restore logits from the copy
-            std::copy((*original_logits).begin(), (*original_logits).end(), logits);
+            std::copy(original_logits.begin(), original_logits.end(), logits);
 
             return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
         }
     }
-    if (original_logits != nullptr) {
-        delete original_logits;
-    }
 
     return id;
 }
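Dropping the delete block is more than cleanup. In the old code, the resampling branch above returns before the delete is ever reached, so every grammar-triggered resample leaked the heap-allocated copy; a stack vector is destroyed on every return path automatically. A simplified sketch of that control flow (hypothetical names, with the recursion standing in for the resampling call):

#include <vector>

// Old shape: the delete sits after an early return and is skipped on resample.
static int sample_old(bool needs_resample) {
    std::vector<float> * original_logits = new std::vector<float>(32, 0.0f);
    if (needs_resample) {
        return sample_old(false);   // leaks: the delete below never runs
    }
    delete original_logits;         // only reached on the non-resampling path
    return 0;
}

// New shape: the destructor runs on every path, early returns included.
static int sample_new(bool needs_resample) {
    std::vector<float> original_logits(32, 0.0f);
    if (needs_resample) {
        return sample_new(false);   // original_logits is freed here too
    }
    return 0;
}

int main() {
    sample_old(true);   // loses one 32-float vector
    sample_new(true);   // loses nothing
    return 0;
}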
@@ -254,7 +251,7 @@ static llama_token_data_array llama_sampling_configure_token_candidates_impl(
         struct llama_context * ctx_cfg,
         const int idx,
         bool apply_grammar,
-        std::vector<float>** original_logits) {
+        std::vector<float>* original_logits) {
     const llama_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -263,7 +260,7 @@ static llama_token_data_array llama_sampling_configure_token_candidates_impl(
     const float penalty_repeat = params.penalty_repeat;
     const float penalty_freq = params.penalty_freq;
     const float penalty_present = params.penalty_present;
 
     const bool penalize_nl = params.penalize_nl;
 
     auto & prev = ctx_sampling->prev;
@@ -272,9 +269,9 @@ static llama_token_data_array llama_sampling_configure_token_candidates_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    if (apply_grammar && original_logits != nullptr) {
+    if (apply_grammar && original_logits != NULL) {
         // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
-        *original_logits = new std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
+        (*original_logits).insert((*original_logits).end(), logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
     }
 
     // apply params.logit_bias map
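One subtlety in the new copy step: insert() appends to the caller's vector rather than assigning it, so correctness relies on the vector arriving empty, which holds here because llama_sampling_sample_impl creates a fresh original_logits = {} on every call. A self-contained sketch of that behavior (the float array is stand-in data, not real logits):

#include <cassert>
#include <vector>

int main() {
    const float logits[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    const int n_vocab = 4;

    std::vector<float> original_logits = {};
    // First call: the vector is empty, so insert() amounts to assignment.
    original_logits.insert(original_logits.end(), logits, logits + n_vocab);
    assert(original_logits.size() == (size_t) n_vocab);

    // A second call on the same vector would double the contents -- harmless
    // in the commit because each sampling call starts from a fresh vector.
    original_logits.insert(original_logits.end(), logits, logits + n_vocab);
    assert(original_logits.size() == 2u * (size_t) n_vocab);
    return 0;
}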
@@ -338,7 +335,7 @@ llama_token_data_array llama_sampling_configure_token_candidates(
         struct llama_context * ctx_cfg,
         const int idx,
         bool apply_grammar,
-        std::vector<float>** original_logits) {
+        std::vector<float>* original_logits) {
     return llama_sampling_configure_token_candidates_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
 }
 
The second changed file updates the matching declaration in the sampling header (note the defaulted arguments):

@@ -138,7 +138,7 @@ llama_token_data_array llama_sampling_configure_token_candidates(
         struct llama_context * ctx_cfg,
         int idx = 0,
         bool apply_grammar = true,
-        std::vector<float>** original_logits = nullptr);
+        std::vector<float>* original_logits = nullptr);
 
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
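The header keeps the nullptr default on the new single-pointer form, so callers that never resample can simply omit the argument; the original_logits != NULL guard in the implementation then skips the copy entirely. A hypothetical caller-side sketch (configure is an illustration, not the real API):

#include <cstddef>
#include <vector>

// Illustrative stand-in for the declaration's defaulted out-parameter.
static void configure(std::vector<float> * original_logits = nullptr) {
    if (original_logits != NULL) {          // mirrors the guard in the commit
        original_logits->push_back(1.0f);   // the copy runs only when requested
    }
}

int main() {
    configure();                 // default nullptr: no logits copy is made

    std::vector<float> saved = {};
    configure(&saved);           // caller opts in to receiving the originals
    return 0;
}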