Add penalty_threshold parameter

Only apply penalties to tokens whose relative frequency in the penalty context is less than or equal to this value.
parent 8f1be0d42f
commit 0f7495469c

7 changed files with 46 additions and 22 deletions
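The rule the diffs below implement: a token whose relative frequency in the penalty context exceeds penalty_threshold is exempt from the repeat, frequency, and presence penalties. A minimal standalone sketch of that rule (toy token ids; not the actual llama.cpp code, which lives in llama_sample_repetition_penalties):

#include <cstdio>
#include <map>
#include <vector>

int main() {
    const std::vector<int> penalty_context = { 5, 7, 5, 9, 5, 5 }; // toy token ids
    const float penalty_threshold = 0.5f;                          // 1.0 = disabled

    // Count occurrences of each token in the penalty context.
    std::map<int, int> counts;
    for (int tok : penalty_context) counts[tok]++;

    for (const auto & [tok, count] : counts) {
        const float rel_freq = float(count) / float(penalty_context.size());
        // A token is exempt when its share of the context exceeds the threshold.
        std::printf("token %d: %d/%zu = %.2f -> %s\n", tok, count, penalty_context.size(),
                    rel_freq, rel_freq <= penalty_threshold ? "penalized" : "exempt");
    }
    return 0;
}
// Output: token 5 (4/6 ≈ 0.67) is exempt; tokens 7 and 9 (1/6 each) are penalized.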
common/common.cpp

@@ -404,6 +404,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             sparams.penalty_present = std::stof(argv[i]);
+        } else if (arg == "--penalty-threshold") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.penalty_threshold = std::stof(argv[i]);
         } else if (arg == "--dynatemp-range") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -976,6 +982,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --repeat-penalty N    penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
     printf("  --presence-penalty N  repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
     printf("  --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
+    printf("  --penalty-threshold N only apply penalties to tokens whose relative frequency in the penalty context is less than or equal to this value (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_threshold);
     printf("  --dynatemp-range N    dynamic temperature range (default: %.1f, 0.0 = disabled)\n", (double)sparams.dynatemp_range);
     printf("  --dynatemp-exp N      dynamic temperature exponent (default: %.1f)\n", (double)sparams.dynatemp_exponent);
     printf("  --mirostat N          use Mirostat sampling.\n");
@@ -1717,6 +1724,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false");
     fprintf(stream, "no_mul_mat_q: %s # default: false\n", !params.mul_mat_q ? "true" : "false");
     fprintf(stream, "no_penalize_nl: %s # default: false\n", !sparams.penalize_nl ? "true" : "false");
+    fprintf(stream, "penalty_threshold: %f # default: 1.0\n", sparams.penalty_threshold);
     fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
     fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
     fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
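Once parsed, the new flag composes with the existing penalty options on the command line; for example (model path is a placeholder):

./main -m model.gguf --repeat-penalty 1.15 --repeat-last-n 128 --penalty-threshold 0.1 --no-penalize-nl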
common/sampling.cpp

@@ -168,6 +168,7 @@ static llama_token llama_sampling_sample_impl(
     const float   penalty_repeat    = params.penalty_repeat;
     const float   penalty_freq      = params.penalty_freq;
     const float   penalty_present   = params.penalty_present;
+    const float   penalty_threshold = params.penalty_threshold;
     const int     mirostat          = params.mirostat;
     const float   mirostat_tau      = params.mirostat_tau;
     const float   mirostat_eta      = params.mirostat_eta;
@@ -215,7 +216,7 @@ static llama_token llama_sampling_sample_impl(
         llama_sample_repetition_penalties(ctx_main, &cur_p,
                 penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
-                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present, penalty_threshold);

         if (!penalize_nl) {
             for (size_t idx = 0; idx < cur_p.size; idx++) {
common/sampling.h

@@ -34,6 +34,7 @@ typedef struct llama_sampling_params {
     float   penalty_repeat    = 1.10f; // 1.0 = disabled
     float   penalty_freq      = 0.00f; // 0.0 = disabled
     float   penalty_present   = 0.00f; // 0.0 = disabled
+    float   penalty_threshold = 1.00f; // 1.0 = disabled
     int32_t mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau      = 5.00f; // target entropy
     float   mirostat_eta      = 0.10f; // learning rate
examples/main/README.md

@@ -182,15 +182,18 @@ Example usage: `--temp 0.5`
 
 - `--repeat-penalty N`: Control the repetition of token sequences in the generated text (default: 1.1).
 - `--repeat-last-n N`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx-size).
+- `--penalty-threshold N`: Only apply penalties to tokens whose relative frequency in the penalty context is less than or equal to this value (default: 1.0, 1.0 = disabled).
 - `--no-penalize-nl`: Disable penalization for newline tokens when applying the repeat penalty.
 
 The `repeat-penalty` option helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1.1.
 
 The `repeat-last-n` option controls the number of tokens in the history to consider for penalizing repetition. A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only consider recent tokens. A value of 0 disables the penalty, and a value of -1 sets the number of tokens considered equal to the context size (`ctx-size`).
 
+The `penalty-threshold` option disables penalties for very common tokens. This is designed to prevent penalizing tokens that are essential to the structure of the text, such as spaces and punctuation, very common words such as "the", names of participants in chats, brackets and tags in code, etc. For example, a value of 0.1 disables penalties for tokens that make up more than 10% of all tokens in the input.
+
 Use the `--no-penalize-nl` option to disable newline penalization when applying the repeat penalty. This option is particularly useful for generating chat conversations, dialogues, code, poetry, or any text where newline tokens play a significant role in structure and formatting. Disabling newline penalization helps maintain the natural flow and intended formatting in these specific use cases.
 
-Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl`
+Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --penalty-threshold 0.1 --no-penalize-nl`
 
 ### Top-K Sampling
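A worked reading of the README example: with `--repeat-last-n 64` and `--penalty-threshold 0.1`, a token is exempt from penalties once count/64 > 0.1, i.e. once it occurs 7 or more times in the 64-token penalty window; a token seen 6 or fewer times is still penalized.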
llama.cpp

@@ -9634,8 +9634,9 @@ void llama_sample_repetition_penalties(
         size_t penalty_last_n,
         float penalty_repeat,
         float penalty_freq,
-        float penalty_present) {
-    if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
+        float penalty_present,
+        float penalty_threshold) {
+    if (penalty_last_n == 0 || penalty_threshold == 0.0f || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
         return;
     }

@@ -9656,6 +9657,10 @@ void llama_sample_repetition_penalties(

         const int count = token_iter->second;

+        if (float(count) / float(penalty_last_n) > penalty_threshold) {
+            continue;
+        }
+
         // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
         if (candidates->data[i].logit <= 0) {
llama.h

@@ -720,6 +720,7 @@ extern "C" {

    /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
    /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
+   /// @param penalty_threshold Only apply penalties to tokens whose relative frequency in the penalty context is less than or equal to this value.
    LLAMA_API void llama_sample_repetition_penalties(
            struct llama_context * ctx,
          llama_token_data_array * candidates,
@@ -727,7 +728,8 @@ extern "C" {
                         size_t   penalty_last_n,
                          float   penalty_repeat,
                          float   penalty_freq,
-                         float   penalty_present);
+                         float   penalty_present,
+                         float   penalty_threshold);

    /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
    /// @param logits Logits extracted from the original generation context.
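For callers, the only change is the new trailing argument. A hedged sketch of a call, modeled directly on the updated test harness below (the tests likewise pass nullptr for the context):

#include <cmath>
#include <vector>
#include "llama.h"

// Sketch: five equally likely candidates, then penalties applied with the
// new trailing penalty_threshold argument.
static void penalty_threshold_example() {
    std::vector<llama_token_data> candidates;
    for (llama_token id = 0; id < 5; id++) {
        candidates.push_back({ id, std::log(0.2f), 0.0f }); // flat 20% distribution
    }
    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    // Token 0 makes up 3/5 = 60% of the penalty context, so with penalty_threshold
    // 0.5 it is exempt; tokens 1 and 2 (20% each) are still penalized.
    const std::vector<llama_token> last_tokens = { 0, 1, 2, 0, 0 };

    llama_sample_repetition_penalties(nullptr, &candidates_p,
            last_tokens.data(), last_tokens.size(),
            /*penalty_repeat=*/50.0f, /*penalty_freq=*/0.0f, /*penalty_present=*/0.0f,
            /*penalty_threshold=*/0.5f);
}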
@ -123,7 +123,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
|
||||||
|
|
||||||
static void test_repetition_penalties(
|
static void test_repetition_penalties(
|
||||||
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
|
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
|
||||||
const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
|
const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence, float penalty_threshold
|
||||||
) {
|
) {
|
||||||
GGML_ASSERT(probs.size() == expected_probs.size());
|
GGML_ASSERT(probs.size() == expected_probs.size());
|
||||||
|
|
||||||
|
@ -138,7 +138,7 @@ static void test_repetition_penalties(
|
||||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||||
llama_sample_softmax(nullptr, &candidates_p);
|
llama_sample_softmax(nullptr, &candidates_p);
|
||||||
DUMP(&candidates_p);
|
DUMP(&candidates_p);
|
||||||
llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
|
llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, penalty_threshold);
|
||||||
llama_sample_softmax(nullptr, &candidates_p);
|
llama_sample_softmax(nullptr, &candidates_p);
|
||||||
DUMP(&candidates_p);
|
DUMP(&candidates_p);
|
||||||
|
|
||||||
|
@ -259,13 +259,17 @@ int main(void) {
|
||||||
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
|
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
|
||||||
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
|
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
|
||||||
|
|
||||||
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f, 1.0f);
|
||||||
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 1.0f);
|
||||||
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 1.0f);
|
||||||
|
|
||||||
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 50.0f, 0.0f, 0.0f, 0.5f);
|
||||||
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.5f);
|
||||||
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
|
test_repetition_penalties({0.125f, 0.125f, 0.125f, 0.125f, 0.125f, 0.125f, 0.125f, 0.125f}, {0, 1, 2, 3, 4, 0, 0, 0, 0}, {0.25f, 0.25f, 0.25f, 0.25f, 0, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.5f);
|
||||||
|
|
||||||
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f, 1.0f);
|
||||||
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f, 1.0f);
|
||||||
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f, 1.0f);
|
||||||
|
|
||||||
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
|
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
|
||||||
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
|
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
|
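What the new 0.5f threshold cases check (the expected probabilities are compared after llama_sample_softmax, which sorts candidates by probability): in the first, token 0 is the entire penalty context, so its relative frequency is 1.0 > 0.5, it is exempt, and the distribution stays flat at 0.2; in the second, tokens 0, 1, and 2 each account for 1/3 of the context, below the threshold, so all three are penalized away and the two untouched tokens split the mass at 0.5; in the third, token 0 appears 5 times in a 9-token context (about 0.56 > 0.5) and is exempt, while tokens 1 through 4 (1/9 each) are penalized, leaving four surviving tokens at 0.25 each.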