sampling : deduplicated code for probability distribution access (#6240)

* sampling: remove duplicated code for probability distribution access

* free original_logits

* fix original_logits allocation

* fixes based on review @cebtenzzre

* change function name to `llama_sampling_prepare`
This commit is contained in:
Minsoo Cheong 2024-03-24 17:54:07 +09:00 committed by GitHub
parent ddf6568510
commit 586e7bc561
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 28 additions and 76 deletions

View file

@ -131,12 +131,14 @@ llama_token llama_sampling_sample(
struct llama_context * ctx_cfg,
int idx = 0);
// returns the probability that token of given id will be sampled
llama_token_data_array llama_sampling_probability_distribution(
// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
llama_token_data_array llama_sampling_prepare(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
int idx = 0);
int idx = 0,
bool apply_grammar = true,
std::vector<float> * original_logits = nullptr);
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,