fixed various issues with sampler pointed out by original creator
This commit is contained in:
parent
4d603e3520
commit
75beda2a84
3 changed files with 36 additions and 30 deletions
|
@ -277,7 +277,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
// DRY sampler parameters
|
// DRY sampler parameters
|
||||||
const float dry_multiplier = params.dry_multiplier;
|
const float dry_multiplier = params.dry_multiplier;
|
||||||
const float dry_base = params.dry_base;
|
const float dry_base = params.dry_base;
|
||||||
const int dry_allowed_length = params.dry_allowed_length;
|
const uint32_t dry_allowed_length = params.dry_allowed_length;
|
||||||
|
|
||||||
auto & prev = ctx_sampling->prev;
|
auto & prev = ctx_sampling->prev;
|
||||||
auto & cur = ctx_sampling->cur;
|
auto & cur = ctx_sampling->cur;
|
||||||
|
@ -324,7 +324,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
llama_sample_dry(ctx_main, &cur_p,
|
llama_sample_dry(ctx_main, &cur_p,
|
||||||
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
||||||
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
|
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
|
||||||
params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
|
params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!penalize_nl) {
|
if (!penalize_nl) {
|
||||||
|
|
|
@ -43,7 +43,7 @@ typedef struct llama_sampling_params {
|
||||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
|
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
|
||||||
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
|
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
|
||||||
float dry_base = 1.75f;
|
float dry_base = 1.75f;
|
||||||
int dry_allowed_length = 2;
|
uint32_t dry_allowed_length = 2;
|
||||||
|
|
||||||
std::vector<llama_sampler_type> samplers_sequence = {
|
std::vector<llama_sampler_type> samplers_sequence = {
|
||||||
llama_sampler_type::TOP_K,
|
llama_sampler_type::TOP_K,
|
||||||
|
@ -64,7 +64,7 @@ typedef struct llama_sampling_params {
|
||||||
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
||||||
|
|
||||||
std::vector<llama_token> penalty_prompt_tokens;
|
std::vector<llama_token> penalty_prompt_tokens;
|
||||||
std::vector<llama_token> dry_sequence_breakers; // sequence breakers for the DRY sampler
|
std::vector<llama_token> dry_seq_breakers; // sequence breakers for the DRY sampler
|
||||||
bool use_penalty_prompt_tokens = false;
|
bool use_penalty_prompt_tokens = false;
|
||||||
} llama_sampling_params;
|
} llama_sampling_params;
|
||||||
|
|
||||||
|
|
36
llama.cpp
36
llama.cpp
|
@ -13233,15 +13233,15 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
|
void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, int dry_seq_breakers_size) {
|
||||||
// sanity check
|
// skip dry sampler if we don't have a previous token
|
||||||
GGML_ASSERT(last_tokens_size > 0);
|
if (last_tokens_size < 1) return;
|
||||||
|
|
||||||
// get the last token
|
// get the last token
|
||||||
auto last_token = last_tokens[last_tokens_size - 1];
|
auto last_token = last_tokens[last_tokens_size - 1];
|
||||||
|
|
||||||
// if last token is part of the sequence breakers, skip whole sampler
|
// if last token is part of the sequence breakers, skip whole sampler
|
||||||
if(std::find(seq_breakers, seq_breakers + seq_breakers_size, last_token) != seq_breakers + seq_breakers_size) {
|
if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -13250,7 +13250,7 @@ void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candi
|
||||||
|
|
||||||
// loop through each previous token (exclude the last token)
|
// loop through each previous token (exclude the last token)
|
||||||
for (size_t i = 0; i < last_tokens_size - 1; ++i) {
|
for (size_t i = 0; i < last_tokens_size - 1; ++i) {
|
||||||
// skip if the compare token if it's not the same as the last token
|
// skip if the compare token is not the same as the last token
|
||||||
if (last_tokens[i] != last_token) {
|
if (last_tokens[i] != last_token) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -13258,7 +13258,12 @@ void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candi
|
||||||
// get the next token (i + 1 is always less than last_tokens_size)
|
// get the next token (i + 1 is always less than last_tokens_size)
|
||||||
auto next_token = last_tokens[i + 1];
|
auto next_token = last_tokens[i + 1];
|
||||||
|
|
||||||
// try to extend the match backwards (match length starts a 1 because last token is already matched)
|
// if next token is part of the sequence breakers, skip
|
||||||
|
if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// try to extend the match backwards (match length starts at 1 because last token is already matched)
|
||||||
size_t match_length = 1;
|
size_t match_length = 1;
|
||||||
|
|
||||||
// loop through the previous tokens
|
// loop through the previous tokens
|
||||||
|
@ -13272,15 +13277,17 @@ void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candi
|
||||||
// head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
|
// head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
|
||||||
auto head_token = last_tokens[last_tokens_size - 1 - match_length];
|
auto head_token = last_tokens[last_tokens_size - 1 - match_length];
|
||||||
|
|
||||||
// if compare token is part of the sequence breakers, break out of the match
|
|
||||||
if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
|
|
||||||
break;
|
|
||||||
|
|
||||||
// break out of the match if any tokens don't match
|
// break out of the match if any tokens don't match
|
||||||
if(compare_token != head_token)
|
if (compare_token != head_token) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if compare token is part of the sequence breakers, break out of the match
|
||||||
|
if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check if the next token exists in the map
|
// Check if the next token exists in the map
|
||||||
auto it = match_lengths.find(next_token);
|
auto it = match_lengths.find(next_token);
|
||||||
|
|
||||||
|
@ -13298,12 +13305,11 @@ void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candi
|
||||||
auto next_token = pair.first;
|
auto next_token = pair.first;
|
||||||
auto match_length = pair.second;
|
auto match_length = pair.second;
|
||||||
|
|
||||||
// if the match length is greater than our allowed length in config, we apply penalities
|
// if the match length is greater than or equal to our allowed length in config, we apply penalities
|
||||||
if(match_length > dry_allowed_length) {
|
if (match_length >= dry_allowed_length) {
|
||||||
|
|
||||||
// find our next token in the candidates->data
|
// find our next token in the candidates->data
|
||||||
size_t i = 0;
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
for (; i < candidates->size; ++i) {
|
|
||||||
if (candidates->data[i].id == next_token) {
|
if (candidates->data[i].id == next_token) {
|
||||||
// calculate the penalty
|
// calculate the penalty
|
||||||
float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
|
float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue