From 802ddd78bf9cbc12336f95ae6e9af321974462d3 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Mon, 29 Jul 2024 19:41:47 +0900 Subject: [PATCH] added sample_dry_impl --- src/llama-sampling.cpp | 90 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 8910f6d65..d41218c70 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -232,6 +232,96 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra } } +void llama_sample_dry_impl(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) { + // skip dry sampler if we don't have a previous token + if (last_tokens_size < 1) return; + + // get the last token + auto last_token = last_tokens[last_tokens_size - 1]; + + // if last token is part of the sequence breakers, skip whole sampler + if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) { + return; + } + + // create an unordered map of "next tokens" <-> max match length + std::unordered_map match_lengths; + + // loop through each previous token (exclude the last token) + for (size_t i = 0; i < last_tokens_size - 1; ++i) { + // skip if the compare token is not the same as the last token + if (last_tokens[i] != last_token) { + continue; + } + + // get the next token (i + 1 is always less than last_tokens_size) + auto next_token = last_tokens[i + 1]; + + // if next token is part of the sequence breakers, skip + if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) { + continue; + } + + // try to extend the match backwards (match length starts at 1 because last token is already matched) + size_t match_length = 1; + + // loop through the previous tokens + for (;; match_length++) { + // if we have reached the start of our last tokens, break + if (i < match_length) break; + + // compare token starts at our prev index, going backwards by match length + auto compare_token = last_tokens[i - match_length]; + + // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself + auto head_token = last_tokens[last_tokens_size - 1 - match_length]; + + // break out of the match if any tokens don't match + if (compare_token != head_token) { + break; + } + + // if compare token is part of the sequence breakers, break out of the match + if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) { + break; + } + } + + // Check if the next token exists in the map + auto it = match_lengths.find(next_token); + + if (it == match_lengths.end()) { + // Key does not exist, insert the new value + match_lengths[next_token] = match_length; + } else { + // Key exists, update it with the max of the new value or the existing value + it->second = std::max(it->second, match_length); + } + } + + // apply penalties + for (const auto& pair : match_lengths) { + auto next_token = pair.first; + auto match_length = pair.second; + + // if the match length is greater than or equal to our allowed length in config, we apply penalities + if (match_length >= dry_allowed_length) { + + // find our next token in the candidates->data + for (size_t i = 0; i < candidates->size; ++i) { + if (candidates->data[i].id == next_token) { + // calculate the penalty + float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length); + + // apply the dry penalty + candidates->data[i].logit -= penalty; + break; + } + } + } + } +} + void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) { if (z >= 1.0f || candidates->size <= 2) { return;