Added shift_p_min
parameter to control probabilities
* the parameter limits how far K-Shift cuts by checking probability of the last token and iterating backwards if it's not probable enough
This commit is contained in:
parent
877a495245
commit
9cae93cfd8
9 changed files with 66 additions and 41 deletions
|
@ -924,11 +924,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--k-shift"}, "N",
|
||||
string_format("k-shift sampling (default: %d, 0 = disabled)", params.sparams.k_shift),
|
||||
string_format("K-Shift sampling (default: %d, 0 = disabled)", params.sparams.k_shift),
|
||||
[](common_params & params, int value) {
|
||||
params.sparams.k_shift = value;
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--shift-p-min"}, "N",
|
||||
string_format("minimum probability required for tokens to be cut by k-shift (default: %.4f, 1.0 = disabled)", params.sparams.shift_p_min),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sparams.shift_p_min = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--top-k"}, "N",
|
||||
string_format("top-k sampling (default: %d, 0 = disabled)", params.sparams.top_k),
|
||||
|
|
|
@ -104,35 +104,36 @@ enum dimre_method {
|
|||
|
||||
// sampler parameters
|
||||
struct common_sampler_params {
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
||||
|
||||
int32_t n_prev = 64; // number of previous tokens to remember
|
||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
|
||||
int32_t k_shift = 0; // 0 = disabled
|
||||
int32_t top_k = 40; // <= 0 to use vocab size
|
||||
float top_p = 0.95f; // 1.0 = disabled
|
||||
float min_p = 0.05f; // 0.0 = disabled
|
||||
float xtc_probability = 0.00f; // 0.0 = disabled
|
||||
float xtc_threshold = 0.10f; // > 0.5 disables XTC
|
||||
float typ_p = 1.00f; // typical_p, 1.0 = disabled
|
||||
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
||||
float dynatemp_range = 0.00f; // 0.0 = disabled
|
||||
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
|
||||
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
float penalty_repeat = 1.00f; // 1.0 = disabled
|
||||
float penalty_freq = 0.00f; // 0.0 = disabled
|
||||
float penalty_present = 0.00f; // 0.0 = disabled
|
||||
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
|
||||
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
|
||||
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
||||
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float mirostat_tau = 5.00f; // target entropy
|
||||
float mirostat_eta = 0.10f; // learning rate
|
||||
bool penalize_nl = false; // consider newlines as a repeatable token
|
||||
int32_t n_prev = 64; // number of previous tokens to remember
|
||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
|
||||
int32_t k_shift = 0; // 0 = disabled
|
||||
float    shift_p_min       = 0.0001f; // >= 1 disables k-shift
|
||||
int32_t top_k = 40; // <= 0 to use vocab size
|
||||
float top_p = 0.95f; // 1.0 = disabled
|
||||
float min_p = 0.05f; // 0.0 = disabled
|
||||
float xtc_probability = 0.00f; // 0.0 = disabled
|
||||
float xtc_threshold = 0.10f; // > 0.5 disables XTC
|
||||
float typ_p = 1.00f; // typical_p, 1.0 = disabled
|
||||
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
||||
float dynatemp_range = 0.00f; // 0.0 = disabled
|
||||
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
|
||||
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
float penalty_repeat = 1.00f; // 1.0 = disabled
|
||||
float penalty_freq = 0.00f; // 0.0 = disabled
|
||||
float penalty_present = 0.00f; // 0.0 = disabled
|
||||
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
|
||||
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
|
||||
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
||||
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float mirostat_tau = 5.00f; // target entropy
|
||||
float mirostat_eta = 0.10f; // learning rate
|
||||
bool penalize_nl = false; // consider newlines as a repeatable token
|
||||
bool ignore_eos = false;
|
||||
bool no_perf = false; // disable performance metrics
|
||||
bool no_perf = false; // disable performance metrics
|
||||
|
||||
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
|
||||
|
||||
|
|
|
@ -188,7 +188,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_K_SHIFT:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_k_shift (params.k_shift));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_k_shift (params.k_shift, params.shift_p_min));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
|
|
|
@ -214,10 +214,11 @@ Example usage: `--dry-multiplier 0.8 --dry-base 1.75 --dry-allowed-length 2 --dr
|
|||
### K-Shift Sampling
|
||||
|
||||
- `--k-shift N`: Shift the first token selection by cutting out N tokens from the top once (default: 0).
|
||||
- `--shift-p-min N`: Sets the minimum probability for tokens that can be cut, ensuring coherency (default: 0.0001).
|
||||
|
||||
K-Shift is a sampling method that guides models away from the most obvious output, eliciting reasoning and analysis. It cuts out k top tokens once at the beginning of inference, making sure that the dialog will start from a less obvious path without guiding the model too much. The method was mentioned in the paper [Chain-of-Thought Reasoning without Prompting](https://arxiv.org/pdf/2402.10200) as a simple trick for guiding a model towards reasoning. In practice, K-Shift can improve the quality of reasoning, help bypass bias or censorship in certain cases, and may also be used as a diagnostics tool. K-Shift is intended to be used with greedy sampling (`--k-shift 10 --top-k 1`), but can help with creative writing too - albeit, not as much as XTC. The default value is 0.
|
||||
|
||||
Example usage: `--k-shift 10`
|
||||
Example usage: `--k-shift 10 --shift-p-min 0.001`
|
||||
|
||||
### Top-K Sampling
|
||||
|
||||
|
|
|
@ -45,6 +45,7 @@
|
|||
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
|
||||
dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
||||
k_shift: 0, // 0 = disabled
|
||||
shift_p_min: 0.0001, // 1.0 = disabled
|
||||
top_k: 0, // 0 = disabled
|
||||
top_p: 1.0, // 1.0 = disabled
|
||||
min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4
|
||||
|
@ -836,6 +837,7 @@ return html`
|
|||
<summary><span class="summary-title">Further Options</span></summary>
|
||||
<fieldset class="params">
|
||||
${IntField({ label: "K-Shift", title: "Cuts out first k tokens once at the start of sampling. Intended to use with greedy sampling.", max: 100, min: 0, step: 1, name: "k_shift", value: params.value.k_shift })}
|
||||
${FloatField({ label: "Shift min probability", title: "Sets the minimum probability limit for tokens cut by K-Shift.", max: 1.0, min: 0.0, step: 0.0001, name: "shift_p_min", value: params.value.shift_p_min })}
|
||||
${IntField({ label: "Top-K", title: "Limits the selection of the next token to the K most probable tokens. 1 means no randomness = greedy sampling. If set to 0, it means the entire vocabulary size is considered.", max: 100, min: 0, step: 1, name: "top_k", value: params.value.top_k })}
|
||||
${IntField({ label: "Penalize Last N", title: "The last n tokens that are taken into account to penalise repetitions. A value of 0 means that this function is deactivated and -1 means that the entire size of the context is taken into account.", max: 2048, min: 0, step: 16, name: "repeat_last_n", value: params.value.repeat_last_n })}
|
||||
${FloatField({ label: "Presence Penalty", title: "A penalty that is applied if certain tokens appear repeatedly in the generated text. A higher value leads to fewer repetitions.", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
|
||||
|
|
|
@ -309,6 +309,7 @@
|
|||
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
|
||||
dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
||||
k_shift: 0, // 0 = disabled
|
||||
shift_p_min: 0.0001, // 1.0 = disabled
|
||||
top_k: 40, // <= 0 to use vocab size
|
||||
top_p: 0.95, // 1.0 = disabled
|
||||
min_p: 0.05, // 0 = disabled
|
||||
|
@ -1008,7 +1009,8 @@
|
|||
${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
|
||||
${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
|
||||
${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
|
||||
${IntField({ label: "K-shift", max: 100, min: -1, name: "k_shift", value: params.value.k_shift })}
|
||||
${IntField({ label: "K-Shift", max: 100, min: -1, name: "k_shift", value: params.value.k_shift })}
|
||||
${FloatField({ label: "Shift min probability", max: 1.0, min: 0.0, name: "shift_p_min", step: 0.0001, value: params.value.shift_p_min })}
|
||||
${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
|
||||
${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
|
||||
${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}
|
||||
|
|
|
@ -801,7 +801,8 @@ struct server_context {
|
|||
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
|
||||
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
|
||||
slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
|
||||
slot.sparams.k_shift = json_value(data, "k_shift", default_sparams.k_shift);
|
||||
slot.sparams.k_shift = json_value(data, "k_shift", default_sparams.k_shift);
|
||||
slot.sparams.shift_p_min = json_value(data, "shift_p_min", default_sparams.shift_p_min);
|
||||
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
|
||||
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
||||
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
||||
|
@ -1142,6 +1143,7 @@ struct server_context {
|
|||
{"dynatemp_range", slot.sparams.dynatemp_range},
|
||||
{"dynatemp_exponent", slot.sparams.dynatemp_exponent},
|
||||
{"k_shift", slot.sparams.k_shift},
|
||||
{"shift_p_min", slot.sparams.shift_p_min},
|
||||
{"top_k", slot.sparams.top_k},
|
||||
{"top_p", slot.sparams.top_p},
|
||||
{"min_p", slot.sparams.min_p},
|
||||
|
|
|
@ -1098,7 +1098,7 @@ extern "C" {
|
|||
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
|
||||
|
||||
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_k_shift (int32_t k);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_k_shift (int32_t k, float p_min);
|
||||
|
||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
|
|
|
@ -188,11 +188,14 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
|
|||
cur_p->size = k;
|
||||
}
|
||||
|
||||
static void llama_sampler_top_shift_impl(llama_token_data_array * cur_p, int k) {
|
||||
static void llama_sampler_top_shift_impl(llama_token_data_array * cur_p, int k, float p_min) {
|
||||
// sort before shifting
|
||||
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
llama_sampler_softmax_impl(cur_p);
|
||||
|
||||
// limit to the first token we define as coherent
|
||||
while (cur_p->data[k].p < p_min && k > 0) {
|
||||
--k;
|
||||
}
|
||||
|
||||
// shift to a token #[k]
|
||||
cur_p->data += k;
|
||||
|
@ -1097,6 +1100,7 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
|
|||
|
||||
struct llama_sampler_k_shift {
|
||||
const int32_t k;
|
||||
const float p_min;
|
||||
bool k_set;
|
||||
};
|
||||
|
||||
|
@ -1107,6 +1111,11 @@ static const char * llama_sampler_k_shift_name(const struct llama_sampler * /*sm
|
|||
static void llama_sampler_k_shift_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (llama_sampler_k_shift *) smpl->ctx;
|
||||
|
||||
// return early if minimum probability is impossible or makes k-shift useless
|
||||
if (ctx->p_min >= 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
// ensures that k-shift can happen on the first step only
|
||||
if (ctx->k_set != true) {
|
||||
ctx->k_set = true;
|
||||
|
@ -1118,13 +1127,13 @@ static void llama_sampler_k_shift_apply(struct llama_sampler * smpl, llama_token
|
|||
return;
|
||||
}
|
||||
|
||||
llama_sampler_top_shift_impl(cur_p, ctx->k);
|
||||
llama_sampler_top_shift_impl(cur_p, ctx->k, ctx->p_min);
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_k_shift_clone(const struct llama_sampler * smpl) {
|
||||
auto * ctx = (const llama_sampler_k_shift *) smpl->ctx;
|
||||
|
||||
return llama_sampler_init_k_shift(ctx->k);
|
||||
return llama_sampler_init_k_shift(ctx->k, ctx->p_min);
|
||||
}
|
||||
|
||||
static void llama_sampler_k_shift_free(struct llama_sampler * smpl) {
|
||||
|
@ -1145,11 +1154,12 @@ static struct llama_sampler_i llama_sampler_k_shift_i = {
|
|||
/* .free = */ llama_sampler_k_shift_free,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_k_shift(int32_t k) {
|
||||
struct llama_sampler * llama_sampler_init_k_shift(int32_t k, float p_min) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_k_shift_i,
|
||||
/* .ctx = */ new llama_sampler_k_shift {
|
||||
/* .k = */ k,
|
||||
/* .k = */ k,
|
||||
/* .p_min = */ p_min,
|
||||
/* .k_set = */ false,
|
||||
},
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue