From 9cae93cfd8b046809e594486e15c640f723490c3 Mon Sep 17 00:00:00 2001
From: MaggotHATE
Date: Sat, 9 Nov 2024 14:54:49 +0500
Subject: [PATCH] Added `shift_p_min` parameter to control probabilities

* the parameter limits how far K-Shift cuts by checking the probability of
  the token at the cut position and iterating backwards while it is not
  probable enough

---
 common/arg.cpp                        |  9 ++++-
 common/common.h                       | 55 ++++++++++++++-------------
 common/sampling.cpp                   |  2 +-
 examples/main/README.md               |  3 +-
 examples/server/public/index-new.html |  2 +
 examples/server/public/index.html     |  4 +-
 examples/server/server.cpp            |  4 +-
 include/llama.h                       |  2 +-
 src/llama-sampling.cpp                | 26 +++++++++----
 9 files changed, 66 insertions(+), 41 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index fff63cbe6..652b65897 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -924,11 +924,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_sparam());
     add_opt(common_arg(
         {"--k-shift"}, "N",
-        string_format("k-shift sampling (default: %d, 0 = disabled)", params.sparams.k_shift),
+        string_format("K-Shift sampling (default: %d, 0 = disabled)", params.sparams.k_shift),
         [](common_params & params, int value) {
             params.sparams.k_shift = value;
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"--shift-p-min"}, "N",
+        string_format("minimum probability required for tokens to be cut by K-Shift (default: %.4f, 1.0 = disabled)", params.sparams.shift_p_min),
+        [](common_params & params, const std::string & value) {
+            params.sparams.shift_p_min = std::stof(value);
+        }
+    ).set_sparam());
     add_opt(common_arg(
         {"--top-k"}, "N",
         string_format("top-k sampling (default: %d, 0 = disabled)", params.sparams.top_k),
diff --git a/common/common.h b/common/common.h
index 9008741f8..d48e5b5c1 100644
--- a/common/common.h
+++ b/common/common.h
@@ -104,35 +104,36 @@ enum dimre_method {
 
 // sampler parameters
 struct common_sampler_params {
-    uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
+    uint32_t seed = LLAMA_DEFAULT_SEED;        // the seed used to initialize llama_sampler
 
-    int32_t n_prev            = 64;    // number of previous tokens to remember
-    int32_t n_probs           = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
-    int32_t min_keep          = 0;     // 0 = disabled, otherwise samplers should return at least min_keep tokens
-    int32_t k_shift           = 0;     // 0 = disabled
-    int32_t top_k             = 40;    // <= 0 to use vocab size
-    float   top_p             = 0.95f; // 1.0 = disabled
-    float   min_p             = 0.05f; // 0.0 = disabled
-    float   xtc_probability   = 0.00f; // 0.0 = disabled
-    float   xtc_threshold     = 0.10f; // > 0.5 disables XTC
-    float   typ_p             = 1.00f; // typical_p, 1.0 = disabled
-    float   temp              = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
-    float   dynatemp_range    = 0.00f; // 0.0 = disabled
-    float   dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
-    int32_t penalty_last_n    = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float   penalty_repeat    = 1.00f; // 1.0 = disabled
-    float   penalty_freq      = 0.00f; // 0.0 = disabled
-    float   penalty_present   = 0.00f; // 0.0 = disabled
-    float   dry_multiplier    = 0.0f;  // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
-    float   dry_base          = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
-    int32_t dry_allowed_length = 2;    // tokens extending repetitions beyond this receive penalty
-    int32_t dry_penalty_last_n = -1;   // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
-    int32_t mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-    float   mirostat_tau      = 5.00f; // target entropy
-    float   mirostat_eta      = 0.10f; // learning rate
-    bool    penalize_nl       = false; // consider newlines as a repeatable token
+    int32_t n_prev             = 64;      // number of previous tokens to remember
+    int32_t n_probs            = 0;       // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t min_keep           = 0;       // 0 = disabled, otherwise samplers should return at least min_keep tokens
+    int32_t k_shift            = 0;       // 0 = disabled
+    float   shift_p_min        = 0.0001f; // >= 1.0 disables k-shift
+    int32_t top_k              = 40;      // <= 0 to use vocab size
+    float   top_p              = 0.95f;   // 1.0 = disabled
+    float   min_p              = 0.05f;   // 0.0 = disabled
+    float   xtc_probability    = 0.00f;   // 0.0 = disabled
+    float   xtc_threshold      = 0.10f;   // > 0.5 disables XTC
+    float   typ_p              = 1.00f;   // typical_p, 1.0 = disabled
+    float   temp               = 0.80f;   // <= 0.0 to sample greedily, 0.0 to not output probabilities
+    float   dynatemp_range     = 0.00f;   // 0.0 = disabled
+    float   dynatemp_exponent  = 1.00f;   // controls how entropy maps to temperature in dynamic temperature sampler
+    int32_t penalty_last_n     = 64;      // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float   penalty_repeat     = 1.00f;   // 1.0 = disabled
+    float   penalty_freq       = 0.00f;   // 0.0 = disabled
+    float   penalty_present    = 0.00f;   // 0.0 = disabled
+    float   dry_multiplier     = 0.0f;    // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
+    float   dry_base           = 1.75f;   // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
+    int32_t dry_allowed_length = 2;       // tokens extending repetitions beyond this receive penalty
+    int32_t dry_penalty_last_n = -1;      // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
+    int32_t mirostat           = 0;       // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float   mirostat_tau       = 5.00f;   // target entropy
+    float   mirostat_eta       = 0.10f;   // learning rate
+    bool    penalize_nl        = false;   // consider newlines as a repeatable token
     bool ignore_eos = false;
-    bool no_perf = false; // disable performance metrics
+    bool no_perf    = false; // disable performance metrics
 
     std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
diff --git a/common/sampling.cpp b/common/sampling.cpp
index abf088452..b96bf9231 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -188,7 +188,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 } break;
             case COMMON_SAMPLER_TYPE_K_SHIFT:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_k_shift    (params.k_shift));
+                llama_sampler_chain_add(result->chain, llama_sampler_init_k_shift    (params.k_shift, params.shift_p_min));
                 break;
             case COMMON_SAMPLER_TYPE_TOP_K:
                 llama_sampler_chain_add(result->chain, llama_sampler_init_top_k      (params.top_k));
diff --git a/examples/main/README.md b/examples/main/README.md
index ca95d4ee3..ba54263ef 100644
--- a/examples/main/README.md
+++ b/examples/main/README.md
@@ -214,10 +214,11 @@ Example usage: `--dry-multiplier 0.8 --dry-base 1.75 --dry-allowed-length 2 --dr
 
 ### K-Shift Sampling
 
 - `--k-shift N`: Shift the first token selection by cutting out N tokens from the top once (default: 0).
+- `--shift-p-min N`: Set the minimum probability a token must have to be cut by K-Shift, keeping the shifted output coherent (default: 0.0001, 1.0 = disabled).
 
 K-Shift is a sampling method that guides models away from the most obvious output, eliciting reasoning and analysis. It cuts out k top tokens once at the beginning of inference, making sure that the dialog will start from a less obvious path without guiding the model too much. The method was mentioned in the paper [Chain-of-Thought Reasoning without Prompting](https://arxiv.org/pdf/2402.10200) as a simple trick for guiding a model towards reasoning. In practice, K-Shift can improve the quality of reasoning, help bypass bias or censorship in certain cases, and may also be used as a diagnostics tool.
 
 K-Shift is intended to be used with greedy sampling (`--k-shift 10 --top-k 1`), but can help with creative writing too, albeit not as much as XTC. The default value is 0.
 
-Example usage: `--k-shift 10`
+Example usage: `--k-shift 10 --shift-p-min 0.001`
 
 ### Top-K Sampling
diff --git a/examples/server/public/index-new.html b/examples/server/public/index-new.html
index 06f4ef7b4..c22720ad3 100644
--- a/examples/server/public/index-new.html
+++ b/examples/server/public/index-new.html
@@ -45,6 +45,7 @@
   dry_allowed_length: 2,  // tokens extending repetitions beyond this receive penalty, 2 works well
   dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
   k_shift: 0,             // <= 0 to use vocab size
+  shift_p_min: 0.0001,    // 0.0 = no limit, 1.0 = disabled
   top_k: 0,               // 0 = disabled
   top_p: 1.0,             // 1.0 = disabled
   min_p: 0.05,            // 0 = disabled; recommended for non-english: ~ 0.4
@@ -836,6 +837,7 @@
       return html`
           Further Options
           ${IntField({ label: "K-Shift", title: "Cuts out first k tokens once at the start of sampling. Intended to use with greedy sampling.", max: 100, min: 0, step: 1, name: "k_shift", value: params.value.k_shift })}
+          ${FloatField({ label: "Shift min probability", title: "Sets the minimum probability a token must have to be cut by K-Shift.", max: 1.0, min: 0.0, step: 0.0001, name: "shift_p_min", value: params.value.shift_p_min })}
           ${IntField({ label: "Top-K", title: "Limits the selection of the next token to the K most probable tokens. 1 means no randomness = greedy sampling. If set to 0, it means the entire vocabulary size is considered.", max: 100, min: 0, step: 1, name: "top_k", value: params.value.top_k })}
           ${IntField({ label: "Penalize Last N", title: "The last n tokens that are taken into account to penalise repetitions. A value of 0 means that this function is deactivated and -1 means that the entire size of the context is taken into account.", max: 2048, min: 0, step: 16, name: "repeat_last_n", value: params.value.repeat_last_n })}
           ${FloatField({ label: "Presence Penalty", title: "A penalty that is applied if certain tokens appear repeatedly in the generated text. A higher value leads to fewer repetitions.", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
diff --git a/examples/server/public/index.html b/examples/server/public/index.html
index 3e486bf9d..a81518176 100644
--- a/examples/server/public/index.html
+++ b/examples/server/public/index.html
@@ -309,6 +309,7 @@
       dry_allowed_length: 2,  // tokens extending repetitions beyond this receive penalty, 2 works well
       dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
       k_shift: 0,             // 0 = disabled
+      shift_p_min: 0.0001,    // 0.0 = no limit, 1.0 = disabled
       top_k: 40,              // <= 0 to use vocab size
       top_p: 0.95,            // 1.0 = disabled
       min_p: 0.05,            // 0 = disabled
@@ -1008,7 +1009,8 @@
             ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
             ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
             ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
-            ${IntField({ label: "K-shift", max: 100, min: -1, name: "k_shift", value: params.value.k_shift })}
+            ${IntField({ label: "K-Shift", max: 100, min: -1, name: "k_shift", value: params.value.k_shift })}
+            ${FloatField({ label: "Shift min probability", max: 1.0, min: 0.0, name: "shift_p_min", step: 0.0001, value: params.value.shift_p_min })}
             ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
             ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
             ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 6c21fd36a..587237c8a 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -801,7 +801,8 @@ struct server_context {
         slot.params.cache_prompt  = json_value(data, "cache_prompt", false);
         slot.params.n_predict     = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
         slot.params.n_indent      = json_value(data, "n_indent", default_params.n_indent);
-        slot.sparams.k_shift      = json_value(data, "k_shift", default_sparams.k_shift);
+        slot.sparams.k_shift      = json_value(data, "k_shift", default_sparams.k_shift);
+        slot.sparams.shift_p_min  = json_value(data, "shift_p_min", default_sparams.shift_p_min);
         slot.sparams.top_k        = json_value(data, "top_k", default_sparams.top_k);
         slot.sparams.top_p        = json_value(data, "top_p", default_sparams.top_p);
         slot.sparams.min_p        = json_value(data, "min_p", default_sparams.min_p);
@@ -1142,6 +1143,7 @@ struct server_context {
             {"dynatemp_range",    slot.sparams.dynatemp_range},
             {"dynatemp_exponent", slot.sparams.dynatemp_exponent},
             {"k_shift",           slot.sparams.k_shift},
+            {"shift_p_min",       slot.sparams.shift_p_min},
             {"top_k",             slot.sparams.top_k},
             {"top_p",             slot.sparams.top_p},
             {"min_p",             slot.sparams.min_p},
diff --git a/include/llama.h b/include/llama.h
index 74b4681e0..96ba18890 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1098,7 +1098,7 @@ extern "C" {
 
     LLAMA_API struct llama_sampler * llama_sampler_init_xtc     (float p, float t, size_t min_keep, uint32_t seed);
 
-    LLAMA_API struct llama_sampler * llama_sampler_init_k_shift (int32_t k);
+    LLAMA_API struct llama_sampler * llama_sampler_init_k_shift (int32_t k, float p_min);
 
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 59786f708..9edc47fe5 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -188,11 +188,14 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
     cur_p->size = k;
 }
 
-static void llama_sampler_top_shift_impl(llama_token_data_array * cur_p, int k) {
+static void llama_sampler_top_shift_impl(llama_token_data_array * cur_p, int k, float p_min) {
     // sort before shifting
-    std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
-        return a.logit > b.logit;
-    });
+    llama_sampler_softmax_impl(cur_p);
+
+    // walk back from token #[k] until the cut lands on a token that is at least p_min probable
+    while (k > 0 && cur_p->data[k].p < p_min) {
+        --k;
+    }
 
     // shift to a token #[k]
     cur_p->data += k;
@@ -1097,6 +1100,7 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
 
 struct llama_sampler_k_shift {
     const int32_t k;
+    const float   p_min;
 
     bool k_set;
 };
@@ -1107,6 +1111,11 @@ static const char * llama_sampler_k_shift_name(const struct llama_sampler * /*sm
 static void llama_sampler_k_shift_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_k_shift *) smpl->ctx;
 
+    // p_min >= 1 can never be satisfied, so it effectively disables k-shift
+    if (ctx->p_min >= 1) {
+        return;
+    }
+
     // ensures that k-shift can happen on the first step only
     if (ctx->k_set != true) {
         ctx->k_set = true;
@@ -1118,13 +1127,13 @@ static void llama_sampler_k_shift_apply(struct llama_sampler * smpl, llama_token
         return;
     }
 
-    llama_sampler_top_shift_impl(cur_p, ctx->k);
+    llama_sampler_top_shift_impl(cur_p, ctx->k, ctx->p_min);
 }
 
 static struct llama_sampler * llama_sampler_k_shift_clone(const struct llama_sampler * smpl) {
     auto * ctx = (const llama_sampler_k_shift *) smpl->ctx;
 
-    return llama_sampler_init_k_shift(ctx->k);
+    return llama_sampler_init_k_shift(ctx->k, ctx->p_min);
 }
 
 static void llama_sampler_k_shift_free(struct llama_sampler * smpl) {
@@ -1145,11 +1154,12 @@ static struct llama_sampler_i llama_sampler_k_shift_i = {
     /* .free   = */ llama_sampler_k_shift_free,
 };
 
-struct llama_sampler * llama_sampler_init_k_shift(int32_t k) {
+struct llama_sampler * llama_sampler_init_k_shift(int32_t k, float p_min) {
     return new llama_sampler {
         /* .iface = */ &llama_sampler_k_shift_i,
         /* .ctx   = */ new llama_sampler_k_shift {
-            /* .k     = */ k,
+            /* .k     = */ k,
+            /* .p_min = */ p_min,
            /* .k_set = */ false,
         },
     };
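
Note for reviewers: below is a minimal, self-contained sketch of the walk-back behaviour that `shift_p_min` introduces, useful for sanity-checking the loop in `llama_sampler_top_shift_impl`. `TokenProb` and `k_shift_cut` are illustrative stand-ins rather than actual llama.cpp types or API, and the bounds clamp is an extra safety assumption of the sketch; only the backing-off loop mirrors the patch.

// Standalone illustration of the p_min-limited K-Shift cut (hypothetical names).
#include <algorithm>
#include <cstdio>
#include <vector>

struct TokenProb {
    int   id; // token id (illustrative)
    float p;  // probability after softmax
};

// Walk back from the requested cut position #[k] until the token there is
// at least p_min probable; returns the index of the new top token.
static size_t k_shift_cut(const std::vector<TokenProb> & sorted, size_t k, float p_min) {
    if (p_min >= 1.0f || sorted.empty()) {
        return 0; // p_min >= 1 can never be satisfied: no shift
    }
    k = std::min(k, sorted.size() - 1); // extra clamp, not in the patch
    while (k > 0 && sorted[k].p < p_min) {
        --k;
    }
    return k;
}

int main() {
    // candidates sorted by descending probability, as after softmax
    const std::vector<TokenProb> cand = {
        {42, 0.60f}, {7, 0.25f}, {13, 0.10f}, {99, 0.04f}, {3, 0.01f},
    };
    // ask to cut 4 tokens, but require at least 5% probability at the cut
    const size_t start = k_shift_cut(cand, 4, 0.05f);
    std::printf("shift lands on token %d (p = %.2f)\n", cand[start].id, cand[start].p);
    // prints: shift lands on token 13 (p = 0.10) -- the cut backed off from k=4 to k=2
    return 0;
}

With `p_min = 0.0` the loop never fires and the full k tokens are cut, matching the previous behaviour; with `p_min >= 1.0` the candidates are left untouched, which is why the patch treats such values as disabling K-Shift.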