llama : add DRY sampler (#9702)
* sampling : add DRY sampler (post-refactor) * DRY: Trying to fix coauthors, removed unneeded line * DRY: Fixed redundant code * DRY: Fixed crash issue due to DRY being in chain but uninitialized --------- Co-authored-by: l3utterfly <gc.pthzfoldr@gmail.com> Co-authored-by: pi6am <34464159+pi6am@users.noreply.github.com>
This commit is contained in:
parent
d80fb71f8b
commit
ff252ea48e
17 changed files with 713 additions and 63 deletions
|
@ -1683,6 +1683,397 @@ struct llama_sampler * llama_sampler_init_penalties(
|
|||
};
|
||||
}
|
||||
|
||||
// DRY
|
||||
|
||||
struct llama_sampler_dry {
|
||||
int32_t total_context_size;
|
||||
|
||||
const float dry_multiplier;
|
||||
const float dry_base;
|
||||
const int32_t dry_allowed_length;
|
||||
const int32_t dry_penalty_last_n;
|
||||
|
||||
std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
|
||||
std::vector<int> dry_repeat_count;
|
||||
std::unordered_map<llama_token, int> dry_max_token_repeat;
|
||||
ring_buffer<llama_token> last_tokens;
|
||||
};
|
||||
|
||||
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
|
||||
static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
|
||||
for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
|
||||
std::string word = llama_detokenize(vocab, {token_id}, true);
|
||||
if (word.find(str) != std::string::npos) {
|
||||
token_sequences.emplace(token_id, std::vector<llama_token>());
|
||||
} else {
|
||||
size_t word_len = word.size(), str_len = str.size();
|
||||
size_t pos = -1;
|
||||
while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
|
||||
bool match = true;
|
||||
size_t i;
|
||||
for (i = 1; i < str_len && i + pos < word_len; ++i) {
|
||||
if (word[pos + i] != str[i]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
|
||||
if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
|
||||
tokenization.resize(max_tail_len);
|
||||
}
|
||||
|
||||
// Ensure we don't already have a duplicate matching tokenization
|
||||
auto its = token_sequences.equal_range(token_id);
|
||||
bool found = false;
|
||||
for (auto it = its.first; it != its.second; ++it) {
|
||||
if (tokenization == it->second) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
token_sequences.emplace(token_id, tokenization);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
|
||||
return "dry";
|
||||
}
|
||||
|
||||
static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
|
||||
auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
||||
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
ctx->last_tokens.push_back(token);
|
||||
}
|
||||
|
||||
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
|
||||
static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
||||
|
||||
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0);
|
||||
int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size);
|
||||
|
||||
if (last_n_repeat <= ctx->dry_allowed_length) {
|
||||
return;
|
||||
}
|
||||
|
||||
ctx->dry_repeat_count.assign(last_n_repeat, 0);
|
||||
ctx->dry_max_token_repeat.clear();
|
||||
|
||||
// Step 1: Look for restart sequences to limit the maximum repetition length.
|
||||
// Work backwards through the context looking for any token that begins a restart sequence.
|
||||
//
|
||||
// The collection `restart_sequences` is a mapping from a "head" token to all "tail"
|
||||
// sequences that together comprise a restart sequence. This allows us to quickly check
|
||||
// whether each token is the head of a complete sequence. Most restart sequences are actually
|
||||
// a single token, and for these the "tail" is an empty vector.
|
||||
//
|
||||
// If the token is a "head", test all restart sequences that begin with this token
|
||||
// (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
|
||||
// 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
|
||||
// longest matching sequence (if any) is used to limit the maximum repetition length.
|
||||
//
|
||||
// Note that in the case case of a short sequence contained in a longer one, this might fail to
|
||||
// find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
|
||||
// restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
|
||||
// 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
|
||||
//
|
||||
// This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
|
||||
// have already clamped the maximum tail sequence length when generating `restart_sequences`.
|
||||
// With clamping, this scan is O(N) in the context length.
|
||||
|
||||
int rep_limit = last_n_repeat;
|
||||
for (int i = 0; i < last_n_repeat; ++i) {
|
||||
llama_token token = ctx->last_tokens.rat(i);
|
||||
auto its = ctx->dry_processed_breakers.equal_range(token);
|
||||
if (its.first == ctx->dry_processed_breakers.end()) {
|
||||
continue;
|
||||
}
|
||||
int longest_match = -1;
|
||||
for (auto it = its.first; it != its.second; ++it) {
|
||||
// Note that (*it) does not contain the head character, so seq_len will be
|
||||
// the restart sequence length minus 1.
|
||||
// In the common case of a single-token restart sequence, (*it) will be empty
|
||||
// and we will trivially match.
|
||||
int seq_len = (int)it->second.size();
|
||||
if (seq_len > longest_match && seq_len <= (int)i) {
|
||||
bool match = true;
|
||||
for (int offset = 0; offset < seq_len; ++offset) {
|
||||
// The -1 when indexing `last_tokens` is because we already matched the head.
|
||||
if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
longest_match = seq_len;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (longest_match >= 0) {
|
||||
// We found a restart sequence starting `i` tokens from the end and continuing for
|
||||
// `longest_match` tokens.
|
||||
rep_limit = i - longest_match;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (rep_limit < ctx->dry_allowed_length) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
|
||||
// the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
|
||||
// elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
|
||||
//
|
||||
// This algorithm is not currently documented on Wikipedia, but there is a clear description here:
|
||||
// https://ivanyu.me/blog/2014/10/15/z-algorithm/
|
||||
//
|
||||
// The code below is adapted from the public domain implementation by the same author here:
|
||||
// https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
|
||||
//
|
||||
// Example:
|
||||
// Last N tokens: a b c c b c y a b c
|
||||
// Repeat counts: 0 0 3 1 0 2 0 0 0 0
|
||||
// ^
|
||||
// This `3` means that the last three tokens of the context (a b c) also appear here.
|
||||
//
|
||||
// This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
|
||||
// for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
|
||||
// repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
|
||||
// ensure that the inner while loops only examine each token in the context once as the outer
|
||||
// for loop iterates over the context.
|
||||
|
||||
{
|
||||
const int last = last_n_repeat - 1;
|
||||
int rt = 0, lt = 0;
|
||||
|
||||
for (int k = 1; k < last_n_repeat; ++k) {
|
||||
if (k > rt) {
|
||||
// If k is outside the current Z-box, do naive computation.
|
||||
int n = 0;
|
||||
while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
|
||||
++n;
|
||||
}
|
||||
ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
|
||||
if (n > 0) {
|
||||
lt = k;
|
||||
rt = k+n-1;
|
||||
}
|
||||
} else {
|
||||
// If k is inside the current Z-box, consider two cases.
|
||||
|
||||
int p = k - lt; // Pair index.
|
||||
int right_part_len = rt - k + 1;
|
||||
|
||||
if (ctx->dry_repeat_count[last - p] < right_part_len) {
|
||||
int n = std::min(ctx->dry_repeat_count[last - p], rep_limit);
|
||||
ctx->dry_repeat_count[last - k] = n;
|
||||
} else {
|
||||
int i = rt + 1;
|
||||
while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) {
|
||||
i += 1;
|
||||
}
|
||||
|
||||
int n = std::min(i - k, rep_limit);
|
||||
ctx->dry_repeat_count[last - k] = n;
|
||||
lt = k;
|
||||
rt = i - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
|
||||
// that would be generated by emitting each new token that would extend a sequence.
|
||||
//
|
||||
// Following the same example as above:
|
||||
// Last N tokens: a b c c b c y a b c
|
||||
// Repeat counts: 0 0 3 1 0 2 0 0 0 0
|
||||
//
|
||||
// For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
|
||||
// c: 3 -> 4 (from `a b c` to `a b c c`)
|
||||
// b: 1 -> 2 (from `c` to `c b`)
|
||||
// y: 2 -> 3 (from `b c` to `b c y`)
|
||||
|
||||
for (int i = 0; i < last_n_repeat - 1; ++i) {
|
||||
int repeat_len = ctx->dry_repeat_count[i];
|
||||
if (repeat_len >= ctx->dry_allowed_length) {
|
||||
// This token ends a repeat, so the next token would continue one.
|
||||
// By convention, the value of `repeat_len` only includes the tokens currently
|
||||
// in the context, not the new token that would be added.
|
||||
llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
|
||||
// Track the maximum sequence ending in this token.
|
||||
const auto& it = ctx->dry_max_token_repeat.find(token);
|
||||
if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
|
||||
ctx->dry_max_token_repeat[token] = repeat_len;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
|
||||
|
||||
// Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
|
||||
// Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
|
||||
const float FLOAT_MAX_LOG = 88.7228391f;
|
||||
int max_exponent = 0;
|
||||
if (ctx->dry_base > 1.000001f) {
|
||||
max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id);
|
||||
if (af_kvp != ctx->dry_max_token_repeat.end()) {
|
||||
// Check all sequence breakers starting with this token
|
||||
auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id);
|
||||
bool is_single_token_breaker = false;
|
||||
|
||||
for (auto it = range.first; it != range.second; ++it) {
|
||||
if (it->second.empty()) {
|
||||
is_single_token_breaker = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Apply penalty only if it's not a single-token sequence breaker
|
||||
if (!is_single_token_breaker) {
|
||||
int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
|
||||
if (max_exponent > 0 && repeat_exp > max_exponent) {
|
||||
repeat_exp = max_exponent;
|
||||
}
|
||||
float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp);
|
||||
cur_p->data[i].logit -= penalty;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cur_p->sorted = false;
|
||||
}
|
||||
|
||||
static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
|
||||
auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
||||
ctx->last_tokens.clear();
|
||||
ctx->dry_repeat_count.clear();
|
||||
ctx->dry_max_token_repeat.clear();
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
||||
|
||||
// nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying
|
||||
auto * result = llama_sampler_init_dry(nullptr, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
|
||||
// Copy the state, including the processed breakers
|
||||
{
|
||||
auto * result_ctx = (llama_sampler_dry *) result->ctx;
|
||||
result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
|
||||
result_ctx->dry_repeat_count = ctx->dry_repeat_count;
|
||||
result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat;
|
||||
result_ctx->last_tokens = ctx->last_tokens;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static void llama_sampler_dry_free(struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_dry *) smpl->ctx;
|
||||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_dry_i = {
|
||||
/* .name = */ llama_sampler_dry_name,
|
||||
/* .accept = */ llama_sampler_dry_accept,
|
||||
/* .apply = */ llama_sampler_dry_apply,
|
||||
/* .reset = */ llama_sampler_dry_reset,
|
||||
/* .clone = */ llama_sampler_dry_clone,
|
||||
/* .free = */ llama_sampler_dry_free,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
|
||||
int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
|
||||
std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
|
||||
const int MAX_CHAR_LEN = 40;
|
||||
const int MAX_SEQ_LEN = 20;
|
||||
|
||||
const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
|
||||
|
||||
if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
|
||||
// Process sequence breakers
|
||||
for (size_t i = 0; i < num_breakers; ++i) {
|
||||
if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
|
||||
LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string sequence_break(seq_breakers[i]);
|
||||
if (sequence_break.empty()) {
|
||||
LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (sequence_break.size() > MAX_CHAR_LEN) {
|
||||
LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
|
||||
sequence_break.resize(MAX_CHAR_LEN);
|
||||
}
|
||||
|
||||
get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
|
||||
}
|
||||
}
|
||||
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_dry_i,
|
||||
/* .ctx = */ new llama_sampler_dry {
|
||||
/* .total_context_size = */ context_size,
|
||||
/* .dry_multiplier = */ dry_multiplier,
|
||||
/* .dry_base = */ dry_base,
|
||||
/* .dry_allowed_length = */ dry_allowed_length,
|
||||
/* .dry_penalty_last_n = */ dry_penalty_last_n,
|
||||
/* .dry_processed_breakers = */ std::move(processed_breakers),
|
||||
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
|
||||
/* .dry_max_token_repeat = */ {},
|
||||
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// wrapper for test-sampling.cpp
|
||||
struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
|
||||
llama_vocab dummy_vocab;
|
||||
auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
|
||||
auto * ctx = (llama_sampler_dry *) result->ctx;
|
||||
|
||||
// Process the token-based sequence breakers
|
||||
ctx->dry_processed_breakers.clear();
|
||||
if (seq_breakers.empty()) {
|
||||
LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n");
|
||||
} else {
|
||||
for (const auto& breaker : seq_breakers) {
|
||||
if (breaker.empty()) {
|
||||
LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n");
|
||||
continue;
|
||||
}
|
||||
llama_token head_token = breaker[0];
|
||||
std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
|
||||
ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens));
|
||||
}
|
||||
|
||||
if (ctx->dry_processed_breakers.empty()) {
|
||||
LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n");
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// logit-bias
|
||||
|
||||
struct llama_sampler_logit_bias {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue