fixed various issues with sampler pointed out by original creator

This commit is contained in:
l3utterfly 2024-04-29 10:01:50 +09:00
parent 4d603e3520
commit 75beda2a84
3 changed files with 36 additions and 30 deletions

View file

@ -275,9 +275,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
const bool penalize_nl = params.penalize_nl;
// DRY sampler parameters
const float dry_multiplier = params.dry_multiplier;
const float dry_base = params.dry_base;
const int dry_allowed_length = params.dry_allowed_length;
const float dry_multiplier = params.dry_multiplier;
const float dry_base = params.dry_base;
const uint32_t dry_allowed_length = params.dry_allowed_length;
auto & prev = ctx_sampling->prev;
auto & cur = ctx_sampling->cur;
@ -320,11 +320,11 @@ static llama_token_data_array llama_sampling_prepare_impl(
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
// DRY penalties (multiplier > 0 means enabled)
if(dry_multiplier > 0.0f) {
if (dry_multiplier > 0.0f) {
llama_sample_dry(ctx_main, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
}
if (!penalize_nl) {

View file

@ -43,7 +43,7 @@ typedef struct llama_sampling_params {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
float dry_base = 1.75f;
int dry_allowed_length = 2;
uint32_t dry_allowed_length = 2;
std::vector<llama_sampler_type> samplers_sequence = {
llama_sampler_type::TOP_K,
@ -64,7 +64,7 @@ typedef struct llama_sampling_params {
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
std::vector<llama_token> penalty_prompt_tokens;
std::vector<llama_token> dry_sequence_breakers; // sequence breakers for the DRY sampler
std::vector<llama_token> dry_seq_breakers; // sequence breakers for the DRY sampler
bool use_penalty_prompt_tokens = false;
} llama_sampling_params;

View file

@ -13233,15 +13233,15 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
}
}
void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
// sanity check
GGML_ASSERT(last_tokens_size > 0);
void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, int dry_seq_breakers_size) {
// skip dry sampler if we don't have a previous token
if (last_tokens_size < 1) return;
// get the last token
auto last_token = last_tokens[last_tokens_size - 1];
// if last token is part of the sequence breakers, skip whole sampler
if(std::find(seq_breakers, seq_breakers + seq_breakers_size, last_token) != seq_breakers + seq_breakers_size) {
if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) {
return;
}
@ -13250,21 +13250,26 @@ void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candi
// loop through each previous token (exclude the last token)
for (size_t i = 0; i < last_tokens_size - 1; ++i) {
// skip if the compare token if it's not the same as the last token
if(last_tokens[i] != last_token) {
// skip if the compare token is not the same as the last token
if (last_tokens[i] != last_token) {
continue;
}
// get the next token (i + 1 is always less than last_tokens_size)
auto next_token = last_tokens[i + 1];
// try to extend the match backwards (match length starts a 1 because last token is already matched)
// if next token is part of the sequence breakers, skip
if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) {
continue;
}
// try to extend the match backwards (match length starts at 1 because last token is already matched)
size_t match_length = 1;
// loop through the previous tokens
for(;; match_length++) {
for (;; match_length++) {
// if we have reached the start of our last tokens, break
if(i < match_length) break;
if (i < match_length) break;
// compare token starts at our prev index, going backwards by match length
auto compare_token = last_tokens[i - match_length];
@ -13272,13 +13277,15 @@ void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candi
// head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
auto head_token = last_tokens[last_tokens_size - 1 - match_length];
// if compare token is part of the sequence breakers, break out of the match
if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
break;
// break out of the match if any tokens don't match
if(compare_token != head_token)
if (compare_token != head_token) {
break;
}
// if compare token is part of the sequence breakers, break out of the match
if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) {
break;
}
}
// Check if the next token exists in the map
@ -13298,12 +13305,11 @@ void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candi
auto next_token = pair.first;
auto match_length = pair.second;
// if the match length is greater than our allowed length in config, we apply penalities
if(match_length > dry_allowed_length) {
// if the match length is greater than or equal to our allowed length in config, we apply penalities
if (match_length >= dry_allowed_length) {
// find our next token in the candidates->data
size_t i = 0;
for (; i < candidates->size; ++i) {
for (size_t i = 0; i < candidates->size; ++i) {
if (candidates->data[i].id == next_token) {
// calculate the penalty
float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
@ -13444,7 +13450,7 @@ void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * c
const int64_t t_start_sample_us = ggml_time_us();
// no need to do anything if there is only one (or zero) candidates
if(candidates_p->size <= 1) {
if (candidates_p->size <= 1) {
return;
}
@ -13678,7 +13684,7 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_
t_start_sample_us = ggml_time_us();
// Compute error as the difference between observed surprise and target surprise value
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
size_t X_idx = std::distance(candidates->data, std::find_if (candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
return candidate.id == X;
}));
float observed_surprise = -log2f(candidates->data[X_idx].p);
@ -13700,7 +13706,7 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok
llama_sample_softmax(ctx, candidates);
// Truncate the words with surprise values greater than mu
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
candidates->size = std::distance(candidates->data, std::find_if (candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
return -log2f(candidate.p) > *mu;
}));
@ -13720,7 +13726,7 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok
t_start_sample_us = ggml_time_us();
// Compute error as the difference between observed surprise and target surprise value
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
size_t X_idx = std::distance(candidates->data, std::find_if (candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
return candidate.id == X;
}));
float observed_surprise = -log2f(candidates->data[X_idx].p);
@ -15770,7 +15776,7 @@ uint64_t llama_model_n_params(const struct llama_model * model) {
}
struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) {
auto it = std::find_if(model->tensors_by_name.begin(), model->tensors_by_name.end(),
auto it = std::find_if (model->tensors_by_name.begin(), model->tensors_by_name.end(),
[name](const std::pair<std::string, struct ggml_tensor *> & it) {
return it.first == name;
});