added implementation of DRY sampler
This commit is contained in:
parent
784e11dea1
commit
f64dea0821
3 changed files with 32 additions and 1 deletions
|
@ -267,13 +267,18 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||
|
||||
// repetition penalties
|
||||
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
|
||||
const float penalty_repeat = params.penalty_repeat;
|
||||
const float penalty_freq = params.penalty_freq;
|
||||
const float penalty_present = params.penalty_present;
|
||||
|
||||
const bool penalize_nl = params.penalize_nl;
|
||||
|
||||
// DRY sampler parameters
|
||||
const float dry_multiplier = params.dry_multiplier;
|
||||
const float dry_base = params.dry_base;
|
||||
const int dry_allowed_length = params.dry_allowed_length;
|
||||
|
||||
auto & prev = ctx_sampling->prev;
|
||||
auto & cur = ctx_sampling->cur;
|
||||
|
||||
|
@ -309,10 +314,20 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||
if (penalty_tokens_used_size) {
|
||||
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
|
||||
|
||||
// repetition penalties
|
||||
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
||||
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
||||
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
|
||||
|
||||
// DRY penalties (multiplier > 0 means enabled)
|
||||
if(dry_multiplier > 0.0f) {
|
||||
llama_sample_dry(ctx_main, &cur_p,
|
||||
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
||||
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
|
||||
params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
|
||||
}
|
||||
|
||||
|
||||
if (!penalize_nl) {
|
||||
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
||||
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
|
||||
|
|
|
@ -41,6 +41,9 @@ typedef struct llama_sampling_params {
|
|||
float mirostat_eta = 0.10f; // learning rate
|
||||
bool penalize_nl = false; // consider newlines as a repeatable token
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
|
||||
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
|
||||
float dry_base = 1.75f;
|
||||
int dry_allowed_length = 2;
|
||||
|
||||
std::vector<llama_sampler_type> samplers_sequence = {
|
||||
llama_sampler_type::TOP_K,
|
||||
|
@ -61,6 +64,7 @@ typedef struct llama_sampling_params {
|
|||
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
||||
|
||||
std::vector<llama_token> penalty_prompt_tokens;
|
||||
std::vector<llama_token> dry_sequence_breakers; // sequence breakers for the DRY sampler
|
||||
bool use_penalty_prompt_tokens = false;
|
||||
} llama_sampling_params;
|
||||
|
||||
|
|
12
llama.h
12
llama.h
|
@ -924,6 +924,18 @@ extern "C" {
|
|||
float p,
|
||||
size_t min_keep);
|
||||
|
||||
/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
|
||||
LLAMA_API void llama_sample_dry(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
const llama_token * last_tokens,
|
||||
int last_tokens_size,
|
||||
float dry_base,
|
||||
float dry_multiplier,
|
||||
int dry_allowed_length,
|
||||
const llama_token * seq_breakers,
|
||||
int seq_breakers_size);
|
||||
|
||||
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
|
||||
LLAMA_API void llama_sample_tail_free(
|
||||
struct llama_context * ctx,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue