Added an implementation of the DRY sampler

l3utterfly 2024-04-25 15:55:34 +09:00
parent 784e11dea1
commit f64dea0821
3 changed files with 32 additions and 1 deletion

common/sampling.cpp

@@ -267,13 +267,18 @@ static llama_token_data_array llama_sampling_prepare_impl(
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
// repetition penalties
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present;
const bool penalize_nl = params.penalize_nl;
// DRY sampler parameters
const float dry_multiplier = params.dry_multiplier;
const float dry_base = params.dry_base;
const int dry_allowed_length = params.dry_allowed_length;
auto & prev = ctx_sampling->prev;
auto & cur = ctx_sampling->cur;
@@ -309,10 +314,20 @@ static llama_token_data_array llama_sampling_prepare_impl(
if (penalty_tokens_used_size) {
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
// repetition penalties
llama_sample_repetition_penalties(ctx_main, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
// DRY penalties (multiplier > 0 means enabled)
if (dry_multiplier > 0.0f) {
llama_sample_dry(ctx_main, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
}
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
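
For reference, the rule this hunk wires in works as follows: DRY looks for the longest suffix of the recent context that already occurred earlier immediately before a candidate token, and penalizes that candidate by dry_multiplier * dry_base^(match_len - dry_allowed_length) once the match reaches dry_allowed_length. Below is a minimal standalone sketch of that rule (a hypothetical helper for illustration only, not the code behind llama_sample_dry; sequence breakers simply stop a match from growing):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <unordered_set>
#include <vector>

// Sketch: penalty to subtract from `candidate`'s logit under DRY.
// Finds the longest suffix of `last_tokens` that already appeared earlier
// immediately followed by `candidate`; below `allowed_length`, no penalty.
static float dry_penalty_sketch(const std::vector<int32_t> & last_tokens,
                                int32_t candidate,
                                float base, float multiplier, int allowed_length,
                                const std::unordered_set<int32_t> & breakers) {
    const int n = (int) last_tokens.size();
    int longest = 0;
    for (int i = 0; i < n; ++i) {
        if (last_tokens[i] != candidate) {
            continue; // candidate would not continue a repeat ending here
        }
        int len = 0;
        // grow the match backwards from position i-1 against the context suffix
        while (len < i && len < n - 1 &&
               last_tokens[i - 1 - len] == last_tokens[n - 1 - len] &&
               breakers.count(last_tokens[i - 1 - len]) == 0) {
            ++len;
        }
        longest = std::max(longest, len);
    }
    if (longest < allowed_length) {
        return 0.0f; // short repeats are allowed unpenalized
    }
    return multiplier * std::pow(base, (float)(longest - allowed_length));
}
```

With the defaults in this commit (base 1.75, allowed length 2) and the recommended multiplier 0.8, a 5-token repeat costs 0.8 * 1.75^3 ≈ 4.3 logits, so the penalty grows quickly with repeat length.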

common/sampling.h

@@ -41,6 +41,9 @@ typedef struct llama_sampling_params {
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = false; // consider newlines as a repeatable token
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
float dry_base = 1.75f;
int dry_allowed_length = 2;
std::vector<llama_sampler_type> samplers_sequence = {
llama_sampler_type::TOP_K,
@@ -61,6 +64,7 @@ typedef struct llama_sampling_params {
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
std::vector<llama_token> penalty_prompt_tokens;
std::vector<llama_token> dry_sequence_breakers; // sequence breakers for the DRY sampler
bool use_penalty_prompt_tokens = false;
} llama_sampling_params;
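
Taken together with dry_sequence_breakers above, enabling DRY from user code might look like this (a configuration sketch; the token ID 13 for newline is a placeholder that depends on the model's tokenizer):

```cpp
llama_sampling_params params;

// 0.0f leaves DRY disabled; 0.8f is the recommended strength noted above
params.dry_multiplier     = 0.8f;
params.dry_base           = 1.75f;
params.dry_allowed_length = 2;

// tokens at which repeat matching stops (placeholder newline token ID)
params.dry_sequence_breakers = { 13 };
```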

llama.h

@@ -924,6 +924,18 @@ extern "C" {
float p,
size_t min_keep);
/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
LLAMA_API void llama_sample_dry(
struct llama_context * ctx,
llama_token_data_array * candidates,
const llama_token * last_tokens,
int last_tokens_size,
float dry_base,
float dry_multiplier,
int dry_allowed_length,
const llama_token * seq_breakers,
int seq_breakers_size);
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sample_tail_free(
struct llama_context * ctx,
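
To close, here is how a caller might invoke the new entry point, mirroring the multiplier gating in llama_sampling_prepare_impl (a sketch; ctx, the candidate array, and the token history are assumed to come from the surrounding sampling loop):

```cpp
#include <vector>
#include "llama.h"

// Sketch: apply DRY right after the standard repetition penalties.
static void apply_dry_sketch(struct llama_context * ctx,
                             llama_token_data_array * cur_p,
                             const std::vector<llama_token> & prev,      // recent token history
                             const std::vector<llama_token> & breakers,  // sequence-breaker IDs
                             float dry_base, float dry_multiplier, int dry_allowed_length) {
    if (dry_multiplier <= 0.0f) {
        return; // disabled, matching the check in the first hunk
    }
    llama_sample_dry(ctx, cur_p,
            prev.data(), (int) prev.size(),
            dry_base, dry_multiplier, dry_allowed_length,
            breakers.data(), (int) breakers.size());
}
```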