added implementation of DRY sampler
This commit is contained in:
parent
784e11dea1
commit
f64dea0821
3 changed files with 32 additions and 1 deletions
|
@ -267,13 +267,18 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
|
|
||||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
|
|
||||||
|
// repetition penalties
|
||||||
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
|
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
|
||||||
const float penalty_repeat = params.penalty_repeat;
|
const float penalty_repeat = params.penalty_repeat;
|
||||||
const float penalty_freq = params.penalty_freq;
|
const float penalty_freq = params.penalty_freq;
|
||||||
const float penalty_present = params.penalty_present;
|
const float penalty_present = params.penalty_present;
|
||||||
|
|
||||||
const bool penalize_nl = params.penalize_nl;
|
const bool penalize_nl = params.penalize_nl;
|
||||||
|
|
||||||
|
// DRY sampler parameters
|
||||||
|
const float dry_multiplier = params.dry_multiplier;
|
||||||
|
const float dry_base = params.dry_base;
|
||||||
|
const int dry_allowed_length = params.dry_allowed_length;
|
||||||
|
|
||||||
auto & prev = ctx_sampling->prev;
|
auto & prev = ctx_sampling->prev;
|
||||||
auto & cur = ctx_sampling->cur;
|
auto & cur = ctx_sampling->cur;
|
||||||
|
|
||||||
|
@ -309,10 +314,20 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
if (penalty_tokens_used_size) {
|
if (penalty_tokens_used_size) {
|
||||||
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
|
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
|
||||||
|
|
||||||
|
// repetition penalties
|
||||||
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
||||||
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
||||||
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
|
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
|
||||||
|
|
||||||
|
// DRY penalties (multiplier > 0 means enabled)
|
||||||
|
if(dry_multiplier > 0.0f) {
|
||||||
|
llama_sample_dry(ctx_main, &cur_p,
|
||||||
|
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
||||||
|
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
|
||||||
|
params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if (!penalize_nl) {
|
if (!penalize_nl) {
|
||||||
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
||||||
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
|
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
|
||||||
|
|
|
@ -41,6 +41,9 @@ typedef struct llama_sampling_params {
|
||||||
float mirostat_eta = 0.10f; // learning rate
|
float mirostat_eta = 0.10f; // learning rate
|
||||||
bool penalize_nl = false; // consider newlines as a repeatable token
|
bool penalize_nl = false; // consider newlines as a repeatable token
|
||||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
|
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
|
||||||
|
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
|
||||||
|
float dry_base = 1.75f;
|
||||||
|
int dry_allowed_length = 2;
|
||||||
|
|
||||||
std::vector<llama_sampler_type> samplers_sequence = {
|
std::vector<llama_sampler_type> samplers_sequence = {
|
||||||
llama_sampler_type::TOP_K,
|
llama_sampler_type::TOP_K,
|
||||||
|
@ -61,6 +64,7 @@ typedef struct llama_sampling_params {
|
||||||
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
||||||
|
|
||||||
std::vector<llama_token> penalty_prompt_tokens;
|
std::vector<llama_token> penalty_prompt_tokens;
|
||||||
|
std::vector<llama_token> dry_sequence_breakers; // sequence breakers for the DRY sampler
|
||||||
bool use_penalty_prompt_tokens = false;
|
bool use_penalty_prompt_tokens = false;
|
||||||
} llama_sampling_params;
|
} llama_sampling_params;
|
||||||
|
|
||||||
|
|
12
llama.h
12
llama.h
|
@ -924,6 +924,18 @@ extern "C" {
|
||||||
float p,
|
float p,
|
||||||
size_t min_keep);
|
size_t min_keep);
|
||||||
|
|
||||||
|
/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
|
||||||
|
LLAMA_API void llama_sample_dry(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
llama_token_data_array * candidates,
|
||||||
|
const llama_token * last_tokens,
|
||||||
|
int last_tokens_size,
|
||||||
|
float dry_base,
|
||||||
|
float dry_multiplier,
|
||||||
|
int dry_allowed_length,
|
||||||
|
const llama_token * seq_breakers,
|
||||||
|
int seq_breakers_size);
|
||||||
|
|
||||||
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
|
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
|
||||||
LLAMA_API void llama_sample_tail_free(
|
LLAMA_API void llama_sample_tail_free(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue