fixes based on review @cebtenzzre

This commit is contained in:
Minsoo Cheong 2024-03-23 00:24:01 +09:00
parent 27f2e85520
commit 0a243da7d4
2 changed files with 5 additions and 5 deletions

View file

@ -173,7 +173,7 @@ static llama_token llama_sampling_sample_impl(
const float mirostat_tau = params.mirostat_tau; const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta; const float mirostat_eta = params.mirostat_eta;
std::vector<float> original_logits = {}; std::vector<float> original_logits;
auto cur_p = llama_sampling_configure_token_candidates(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits); auto cur_p = llama_sampling_configure_token_candidates(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
if (!is_resampling) { if (!is_resampling) {
GGML_ASSERT(!original_logits.empty()); GGML_ASSERT(!original_logits.empty());
@ -251,7 +251,7 @@ static llama_token_data_array llama_sampling_configure_token_candidates_impl(
struct llama_context * ctx_cfg, struct llama_context * ctx_cfg,
const int idx, const int idx,
bool apply_grammar, bool apply_grammar,
std::vector<float>* original_logits) { std::vector<float> * original_logits) {
const llama_sampling_params & params = ctx_sampling->params; const llama_sampling_params & params = ctx_sampling->params;
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@ -271,7 +271,7 @@ static llama_token_data_array llama_sampling_configure_token_candidates_impl(
if (apply_grammar && original_logits != NULL) { if (apply_grammar && original_logits != NULL) {
// Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this. // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
(*original_logits).insert((*original_logits).end(), logits, logits + llama_n_vocab(llama_get_model(ctx_main))); *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
} }
// apply params.logit_bias map // apply params.logit_bias map
@ -335,7 +335,7 @@ llama_token_data_array llama_sampling_configure_token_candidates(
struct llama_context * ctx_cfg, struct llama_context * ctx_cfg,
const int idx, const int idx,
bool apply_grammar, bool apply_grammar,
std::vector<float>* original_logits) { std::vector<float> * original_logits) {
return llama_sampling_configure_token_candidates_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits); return llama_sampling_configure_token_candidates_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
} }

View file

@ -138,7 +138,7 @@ llama_token_data_array llama_sampling_configure_token_candidates(
struct llama_context * ctx_cfg, struct llama_context * ctx_cfg,
int idx = 0, int idx = 0,
bool apply_grammar = true, bool apply_grammar = true,
std::vector<float>* original_logits = nullptr); std::vector<float> * original_logits = nullptr);
void llama_sampling_accept( void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,