added sample_dry_impl
This commit is contained in:
parent 2f9a36a4f9
commit 802ddd78bf
1 changed file with 90 additions and 0 deletions
@@ -232,6 +232,96 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array
    }
}

void llama_sample_dry_impl(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
    // skip dry sampler if we don't have a previous token
    if (last_tokens_size < 1) return;

    // get the last token
    auto last_token = last_tokens[last_tokens_size - 1];

    // if last token is part of the sequence breakers, skip whole sampler
    if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) {
        return;
    }

    // create an unordered map of "next tokens" <-> max match length
    std::unordered_map<llama_token, size_t> match_lengths;

    // loop through each previous token (exclude the last token)
    for (size_t i = 0; i < last_tokens_size - 1; ++i) {
        // skip if the compare token is not the same as the last token
        if (last_tokens[i] != last_token) {
            continue;
        }

        // get the next token (i + 1 is always less than last_tokens_size)
        auto next_token = last_tokens[i + 1];

        // if next token is part of the sequence breakers, skip
        if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) {
            continue;
        }

        // try to extend the match backwards (match length starts at 1 because last token is already matched)
        size_t match_length = 1;

        // loop through the previous tokens
        for (;; match_length++) {
            // if we have reached the start of our last tokens, break
            if (i < match_length) break;

            // compare token starts at our prev index, going backwards by match length
            auto compare_token = last_tokens[i - match_length];

            // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
            auto head_token = last_tokens[last_tokens_size - 1 - match_length];

            // break out of the match if any tokens don't match
            if (compare_token != head_token) {
                break;
            }

            // if compare token is part of the sequence breakers, break out of the match
            if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) {
                break;
            }
        }

        // Check if the next token exists in the map
        auto it = match_lengths.find(next_token);

        if (it == match_lengths.end()) {
            // Key does not exist, insert the new value
            match_lengths[next_token] = match_length;
        } else {
            // Key exists, update it with the max of the new value or the existing value
            it->second = std::max(it->second, match_length);
        }
    }

    // apply penalties
    for (const auto& pair : match_lengths) {
        auto next_token = pair.first;
        auto match_length = pair.second;

        // if the match length is greater than or equal to our allowed length in config, we apply penalties
        if (match_length >= dry_allowed_length) {
            // find our next token in the candidates->data
            for (size_t i = 0; i < candidates->size; ++i) {
                if (candidates->data[i].id == next_token) {
                    // calculate the penalty
                    float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);

                    // apply the dry penalty
                    candidates->data[i].logit -= penalty;
                    break;
                }
            }
        }
    }
}

void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
    if (z >= 1.0f || candidates->size <= 2) {
        return;
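For readers skimming the diff: the match-length bookkeeping in llama_sample_dry_impl above can be hard to follow out of context. Below is a minimal, self-contained sketch (not part of this commit) that restates the same backward-matching idea on plain ints, with the sequence-breaker checks omitted; the token history values are made up for illustration.

// Illustration only: mirrors the match-length loop from llama_sample_dry_impl,
// but on plain ints and without sequence breakers.
#include <algorithm>
#include <cstdio>
#include <unordered_map>
#include <vector>

int main() {
    // hypothetical token history; the last token is 6
    std::vector<int> last_tokens = {5, 6, 7, 5, 6, 7, 5, 6};
    int last_token = last_tokens.back();

    std::unordered_map<int, size_t> match_lengths;

    for (size_t i = 0; i + 1 < last_tokens.size(); ++i) {
        if (last_tokens[i] != last_token) continue;
        int next_token = last_tokens[i + 1];

        // extend the match backwards from position i and from the end of the history
        size_t match_length = 1;
        for (;; match_length++) {
            if (i < match_length) break;
            if (last_tokens[i - match_length] != last_tokens[last_tokens.size() - 1 - match_length]) break;
        }

        // keep the longest match seen for this "next token"
        auto it = match_lengths.find(next_token);
        if (it == match_lengths.end()) {
            match_lengths[next_token] = match_length;
        } else {
            it->second = std::max(it->second, match_length);
        }
    }

    for (const auto & kv : match_lengths) {
        printf("continuing with token %d would extend a repeat of length %zu\n", kv.first, kv.second);
    }
    return 0;
}

With this history the sketch reports that continuing with token 7 would extend a repeat of length 5, which is exactly the quantity the penalty stage of the sampler then acts on.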
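The penalty line itself, penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length), grows exponentially with the repeat length. A small standalone check with assumed parameter values (dry_multiplier = 0.8, dry_base = 1.75, dry_allowed_length = 2; these are example values, not taken from this commit):

#include <cmath>
#include <cstdio>

int main() {
    const float dry_multiplier     = 0.8f;   // assumed example value
    const float dry_base           = 1.75f;  // assumed example value
    const int   dry_allowed_length = 2;      // assumed example value

    // same formula as in llama_sample_dry_impl: the logit of a token that would
    // extend a repeat of match_length tokens is reduced by this amount
    for (int match_length = dry_allowed_length; match_length <= 6; ++match_length) {
        float penalty = dry_multiplier * powf(dry_base, (float)(match_length - dry_allowed_length));
        printf("match_length = %d -> penalty = %.3f\n", match_length, penalty);
    }
    return 0;
}

So a token that would just reach the allowed length is penalized by dry_multiplier, and each additional repeated token multiplies the penalty by dry_base.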