Attempt at a slightly optimized vector-of-strings DRY implementation
parent 9105cf435b
commit 6579e64f26
6 changed files with 213 additions and 59 deletions
@@ -433,10 +433,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
     {
         const int penalty_tokens_used_size = std::min(penalty_tokens.size(), (size_t)dry_penalty_last_n);
         if (penalty_tokens_used_size) {
-            llama_sample_dry(&cur_p,
+            llama_sample_dry(ctx_main, &cur_p,
                              penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
                              penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
-                             params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
+                             params.dry_seq_breakers);
         }
     }

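A note on the call-site arithmetic, since only the `llama_sample_dry` signature actually changes in this hunk: `penalty_tokens_used_size` clamps the DRY window to at most `dry_penalty_last_n` tokens, and the pointer offset hands the sampler only the tail of the history. A minimal standalone sketch of that windowing, with hypothetical sizes (not part of this diff):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    using llama_token = int32_t; // matches the typedef in llama.h

    int main() {
        std::vector<llama_token> penalty_tokens(4096); // assumed history length
        int32_t dry_penalty_last_n = 1024;             // -1 means "use the whole context"

        // same clamp as in the hunk: never read more tokens than we have
        const int used = std::min(penalty_tokens.size(), (size_t) dry_penalty_last_n);

        // llama_sample_dry sees only the most recent `used` tokens
        const llama_token * window = penalty_tokens.data() + penalty_tokens.size() - used;
        (void) window;
        return 0;
    }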
@@ -46,6 +46,8 @@ typedef struct llama_sampling_params {
     uint32_t dry_allowed_length = 2;
     int32_t  dry_penalty_last_n = -1;    // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
+
+    std::vector<std::string> dry_seq_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY

     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
         llama_sampler_type::TFS_Z,
@@ -63,9 +65,8 @@ typedef struct llama_sampling_params {
     float cfg_scale = 1.f; // how strong is guidance

     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

     std::vector<llama_token> penalty_prompt_tokens;
-    std::vector<llama_token> dry_seq_breakers; // sequence breakers for the DRY sampler
     bool use_penalty_prompt_tokens = false;
 } llama_sampling_params;

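Taken together, the two struct hunks above replace the token-id breaker list with raw strings and give it defaults, so breakers no longer have to be pre-tokenized by the caller and remain valid across models with different vocabularies. A hedged sketch of overriding the defaults (field names come from the header above; the include path and surrounding setup are assumptions):

    #include "sampling.h" // assumed: common/sampling.h, which defines llama_sampling_params

    llama_sampling_params sparams;
    // defaults are {"\n", ":", "\"", "*"}; adding a breaker is now plain string
    // handling, with tokenization deferred to the sampler itself
    sparams.dry_seq_breakers.push_back("USER:"); // hypothetical chat-style breaker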
@@ -1085,16 +1085,17 @@ extern "C" {
             float   p,
             size_t  min_keep);

-    /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
-    LLAMA_API void llama_sample_dry(
-            llama_token_data_array * candidates,
-            const llama_token * last_tokens,
-            size_t last_tokens_size,
-            float dry_base,
-            float dry_multiplier,
-            int dry_allowed_length,
-            const llama_token * dry_seq_breakers,
-            size_t dry_seq_breakers_size);
+    // /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
+    // LLAMA_API void llama_sample_dry(
+    //         struct llama_context * ctx,
+    //         llama_token_data_array * candidates,
+    //         const llama_token * last_tokens,
+    //         size_t last_tokens_size,
+    //         float dry_base,
+    //         float dry_multiplier,
+    //         int dry_allowed_length,
+    //         const std::vector<std::string>
+    //         & dry_seq_breakers);

     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
     LLAMA_API void llama_sample_tail_free(
@@ -1246,6 +1247,18 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
 // This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
 llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);

+/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
+LLAMA_API void llama_sample_dry(
+        struct llama_context * ctx,
+        llama_token_data_array * candidates,
+        const llama_token * last_tokens,
+        size_t last_tokens_size,
+        float dry_base,
+        float dry_multiplier,
+        int dry_allowed_length,
+        const std::vector<std::string>
+        & dry_seq_breakers);
+
 #endif // LLAMA_API_INTERNAL

 #endif // LLAMA_H

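The two llama.h hunks are worth reading as a pair: the old C-ABI declaration inside the `extern "C"` block is replaced by a commented-out copy, and the live declaration moves under `#ifdef LLAMA_API_INTERNAL`, presumably because a `std::vector<std::string>` parameter cannot cross a C boundary. A sketch of an internal C++ caller under that assumption (the parameter values are hypothetical, not taken from this diff):

    // only C++ translation units built with LLAMA_API_INTERNAL see this overload
    std::vector<std::string> breakers = {"\n", ":", "\"", "*"};
    llama_sample_dry(ctx, &cur_p,
                     last_tokens.data(), last_tokens.size(),
                     /*dry_base*/ 1.75f, /*dry_multiplier*/ 0.8f,
                     /*dry_allowed_length*/ 2, breakers);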
@@ -232,94 +232,230 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
     }
 }

-void llama_sample_dry_impl(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
-    // skip dry sampler if we don't have a previous token
-    if (last_tokens_size < 1) return;
+std::vector<llama_token> llama_tokenize(
+        const struct llama_context * ctx,
+        const std::string & text,
+        bool add_special,
+        bool parse_special) {
+    return llama_tokenize(llama_get_model(ctx), text, add_special, parse_special);
+}

-    // get the last token
-    auto last_token = last_tokens[last_tokens_size - 1];
+std::vector<llama_token> llama_tokenize(
+        const struct llama_model * model,
+        const std::string & text,
+        bool add_special,
+        bool parse_special) {
+    // upper limit for the number of tokens
+    int n_tokens = text.length() + 2 * add_special;
+    std::vector<llama_token> result(n_tokens);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    if (n_tokens < 0) {
+        result.resize(-n_tokens);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+        GGML_ASSERT(check == -n_tokens);
+    } else {
+        result.resize(n_tokens);
+    }
+    return result;
+}

-    // if last token is part of the sequence breakers, skip whole sampler
-    if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) {
+std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+    std::string text;
+    text.resize(std::max(text.capacity(), tokens.size()));
+    int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    if (n_chars < 0) {
+        text.resize(-n_chars);
+        n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
+    }
+
+    text.resize(n_chars);
+
+    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+    return text;
+}
+
+std::string llama_detokenize_single(llama_context * ctx, llama_token token, bool special) {
+    std::vector<llama_token> tokens = {token};
+    return llama_detokenize(ctx, tokens, special);
+}
+
+// Constants for preventing overflow
+const float FLOAT_MAX_LOG = 88.7228391f;
+const int MAX_CHAR_LEN = 40;
+const int MAX_SEQ_LEN = 20;
+
+
+void llama_sample_dry_impl(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers) {
+    if (last_tokens_size < 1) {
         return;
     }

-    // create an unordered map of "next tokens" <-> max match length
+    // Cache for token-to-string conversions
+    std::unordered_map<llama_token, std::string> token_to_string_cache;
+    // Store sequence breakers for more efficient lookup
+    std::unordered_multimap<std::string, std::vector<std::string>> restart_sequences;
+
+    auto detokenize_with_cache = [&](llama_token token) -> std::string {
+        auto it = token_to_string_cache.find(token);
+        if (it != token_to_string_cache.end()) {
+            return it->second;
+        }
+        std::string token_str = llama_detokenize_single(ctx, token, false);
+        token_to_string_cache[token] = token_str;
+        return token_str;
+    };
+
+    // Pre-process dry_seq_breakers
+    for (const auto& breaker : dry_seq_breakers) {
+        std::string breaker_trimmed = breaker.substr(0, MAX_CHAR_LEN);
+        std::vector<llama_token> tokens = llama_tokenize(ctx, breaker_trimmed, false, false);
+
+        if (!tokens.empty()) {
+            std::string head = detokenize_with_cache(tokens[0]);
+            std::vector<std::string> tail;
+
+            for (size_t i = 1; i < tokens.size() && i <= MAX_SEQ_LEN; ++i) {
+                tail.push_back(detokenize_with_cache(tokens[i]));
+            }
+            restart_sequences.emplace(head, tail);
+        }
+    }
+
+    // Find max repetition length considering restart sequences
+    int rep_limit = last_tokens_size;
+
+    for (size_t i = 0; i < last_tokens_size; ++i) {
+        size_t ix = last_tokens_size - 1 - i;
+        std::string token_str = detokenize_with_cache(last_tokens[ix]);
+
+        // Check if the token is a potential sequence breaker
+        auto its = restart_sequences.equal_range(token_str);
+        if (its.first == restart_sequences.end()) continue;
+
+        int longest_match = -1;
+        // Check all potential sequence breakers starting with this token
+        for (auto it = its.first; it != its.second; ++it) {
+            int seq_len = (int)it->second.size();
+            if (seq_len > longest_match && seq_len <= i) {
+                bool match = true;
+                // Check if the following tokens match the sequence breaker
+                for (size_t offset = 0; offset < seq_len; ++offset) {
+                    if (it->second[offset] != detokenize_with_cache(last_tokens[ix + 1 + offset])) {
+                        match = false;
+                        break;
+                    }
+                }
+                if (match) {
+                    longest_match = seq_len;
+                }
+            }
+        }
+
+        if (longest_match >= 0) {
+            rep_limit = static_cast<int>(i) - longest_match;
+            break;
+        }
+    }
+
+    if (rep_limit <= dry_allowed_length) {
+        return;
+    }
+
+    // Store max match length for each token
     std::unordered_map<llama_token, size_t> match_lengths;

-    // loop through each previous token (exclude the last token)
+    // Find repeated sequences
     for (size_t i = 0; i < last_tokens_size - 1; ++i) {
-        // skip if the compare token is not the same as the last token
-        if (last_tokens[i] != last_token) {
+        if (last_tokens[i] != last_tokens[last_tokens_size - 1]) {
             continue;
         }

         // get the next token (i + 1 is always less than last_tokens_size)
         auto next_token = last_tokens[i + 1];
+        std::string next_token_str = detokenize_with_cache(next_token);

-        // if next token is part of the sequence breakers, skip
-        if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) {
+        // Skip if next token is a sequence breaker
+        auto its = restart_sequences.equal_range(next_token_str);
+        if (its.first != restart_sequences.end()) {
             continue;
         }

         // try to extend the match backwards (match length starts at 1 because last token is already matched)
         size_t match_length = 1;

-        // loop through the previous tokens
+        // Extend match as far as possible
         for (;; match_length++) {
-            // if we have reached the start of our last tokens, break
-            if (i < match_length) break;
-
-            // compare token starts at our prev index, going backwards by match length
-            auto compare_token = last_tokens[i - match_length];
-
-            // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
-            auto head_token = last_tokens[last_tokens_size - 1 - match_length];
-
-            // break out of the match if any tokens don't match
-            if (compare_token != head_token) {
+            if (i < match_length || match_length > rep_limit) {
                 break;
             }

-            // if compare token is part of the sequence breakers, break out of the match
-            if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) {
+            auto compare_token = last_tokens[i - match_length];
+            std::string compare_token_str = detokenize_with_cache(compare_token);
+
+            auto head_token = last_tokens[last_tokens_size - 1 - match_length];
+            std::string head_token_str = detokenize_with_cache(head_token);
+
+            if (compare_token_str != head_token_str) {
+                break;
+            }
+
+            // Check if we've hit a sequence breaker
+            its = restart_sequences.equal_range(compare_token_str);
+            if (its.first != restart_sequences.end()) {
                 break;
             }
         }

-        // Check if the next token exists in the map
+        // Update max match length for this token
         auto it = match_lengths.find(next_token);

         if (it == match_lengths.end()) {
             // Key does not exist, insert the new value
             match_lengths[next_token] = match_length;
         } else {
             // Key exists, update it with the max of the new value or the existing value
             it->second = std::max(it->second, match_length);
         }
     }

-    // apply penalties
+    // Calculate max safe exponent
+    int max_exponent = 0;
+    if (dry_base > 1.000001f) {
+        max_exponent = static_cast<int>(FLOAT_MAX_LOG / log(dry_base));
+    }
+
+#ifdef DEBUG
+    LLAMA_LOG_INFO("DRY Sampling parameters:\n");
+    LLAMA_LOG_INFO("  dry_base: %f\n", dry_base);
+    LLAMA_LOG_INFO("  dry_multiplier: %f\n", dry_multiplier);
+    LLAMA_LOG_INFO("  dry_allowed_length: %d\n", dry_allowed_length);
+    LLAMA_LOG_INFO("  max_exponent: %d\n", max_exponent);
+    LLAMA_LOG_INFO("DRY penalties [");
+#endif
+
+    // Apply penalties
     for (const auto& pair : match_lengths) {
         auto next_token = pair.first;
         auto match_length = pair.second;

-        // if the match length is greater than or equal to our allowed length in config, we apply penalities
-        if (match_length >= (size_t)dry_allowed_length) {
-
-            // find our next token in the candidates->data
+        if (match_length >= static_cast<size_t>(dry_allowed_length)) {
             for (size_t i = 0; i < candidates->size; ++i) {
                 if (candidates->data[i].id == next_token) {
-                    // calculate the penalty
-                    float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
-
-                    // apply the dry penalty
+                    int repeat_exp = static_cast<int>(match_length - dry_allowed_length);
+                    if (max_exponent > 0 && repeat_exp > max_exponent) {
+                        repeat_exp = max_exponent;
+                    }
+                    float penalty = dry_multiplier * pow(dry_base, static_cast<float>(repeat_exp));
                     candidates->data[i].logit -= penalty;
+
+#ifdef DEBUG
+                    LLAMA_LOG_INFO("  Token %d: %s (Penalty: %.2f)", next_token, detokenize_with_cache(next_token).c_str(), penalty);
+#endif
                     break;
                 }
             }
         }
     }
+
+#ifdef DEBUG
+    LLAMA_LOG_INFO("]\n");
+#endif
 }

 void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
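Two parts of this hunk deserve unpacking. First, matching now happens at the string level: each breaker is tokenized, the first piece becomes the key of `restart_sequences`, and the remaining pieces form a tail that must also match for a multi-token breaker to count; `detokenize_with_cache` keeps the repeated token-to-string conversions cheap. Second, the penalty is `dry_multiplier * dry_base^(match_length - dry_allowed_length)`, and `max_exponent = FLOAT_MAX_LOG / ln(dry_base)` caps the exponent so `pow` cannot overflow a float (`exp(88.7228391f)` is roughly `FLT_MAX`). A self-contained sketch of just the clamp, with hypothetical parameter values:

    #include <cmath>
    #include <cstdio>

    int main() {
        const float FLOAT_MAX_LOG      = 88.7228391f; // ln(FLT_MAX), as in the hunk
        const float dry_base           = 1.75f;       // hypothetical values
        const float dry_multiplier     = 0.8f;
        const int   dry_allowed_length = 2;

        // ln(1.75) is about 0.56, so max_exponent is about 158; only degenerate,
        // extremely long matches ever hit the clamp
        int max_exponent = 0;
        if (dry_base > 1.000001f) {
            max_exponent = (int)(FLOAT_MAX_LOG / std::log(dry_base));
        }

        for (int match_length = 3; match_length <= 6; ++match_length) {
            int repeat_exp = match_length - dry_allowed_length;
            if (max_exponent > 0 && repeat_exp > max_exponent) {
                repeat_exp = max_exponent;
            }
            // e.g. match_length 5 gives 0.8 * 1.75^3, about 4.29 off the logit
            float penalty = dry_multiplier * std::pow(dry_base, (float) repeat_exp);
            std::printf("match_length=%d penalty=%.2f\n", match_length, penalty);
        }
        return 0;
    }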
@@ -28,7 +28,11 @@ void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
 void llama_sample_top_k_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
 void llama_sample_top_p_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
 void llama_sample_min_p_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_dry_impl    (llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size);
+std::vector<llama_token> llama_tokenize(const struct llama_context * ctx, const std::string & text, bool add_special, bool parse_special);
+std::vector<llama_token> llama_tokenize(const struct llama_model * model, const std::string & text, bool add_special, bool parse_special);
+std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special);
+std::string llama_detokenize_single(llama_context * ctx, llama_token token, bool special);
+void llama_sample_dry_impl    (struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers);
 void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
 void llama_sample_typical_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
 void llama_sample_entropy_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
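These internal declarations expose the std::string convenience wrappers alongside the sampler entry points; they overload the same names as the C-ABI `llama_tokenize`/`llama_detokenize` functions they forward to. A quick usage sketch (identifiers come from the declarations above; the `ctx` setup is assumed):

    // round-trip a string through the new helpers (hypothetical usage)
    std::vector<llama_token> toks = llama_tokenize(ctx, "Hello", /*add_special*/ false, /*parse_special*/ false);
    std::string text  = llama_detokenize(ctx, toks, /*special*/ false);
    std::string piece = llama_detokenize_single(ctx, toks.front(), /*special*/ false); // assumes toks is non-empty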
@@ -18935,8 +18935,8 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
     llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
 }

-void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
-    llama_sample_dry_impl(candidates, last_tokens, last_tokens_size, dry_base, dry_multiplier, dry_allowed_length, dry_seq_breakers, dry_seq_breakers_size);
+void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers) {
+    llama_sample_dry_impl(ctx, candidates, last_tokens, last_tokens_size, dry_base, dry_multiplier, dry_allowed_length, dry_seq_breakers);
 }

 void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
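One observable difference from the neighboring wrappers: `llama_sample_min_p` above guards with `ctx ? &ctx->sampling : nullptr`, while the new DRY wrapper forwards `ctx` unchecked, since the implementation needs the full context for tokenization. A defensive variant is sketched below; it is hypothetical and not part of this diff:

    // hypothetical null-safe wrapper, mirroring the ctx guard the other samplers use
    void llama_sample_dry_checked(struct llama_context * ctx, llama_token_data_array * candidates,
                                  const llama_token * last_tokens, size_t last_tokens_size,
                                  float dry_base, float dry_multiplier, int dry_allowed_length,
                                  const std::vector<std::string> & dry_seq_breakers) {
        if (ctx == nullptr) {
            return; // DRY needs a model for tokenize/detokenize; skip rather than crash
        }
        llama_sample_dry_impl(ctx, candidates, last_tokens, last_tokens_size,
                              dry_base, dry_multiplier, dry_allowed_length, dry_seq_breakers);
    }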