diff --git a/common/arg.cpp b/common/arg.cpp index 7c5c5e5cd..fff63cbe6 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -922,6 +922,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sparams.temp = std::max(params.sparams.temp, 0.0f); } ).set_sparam()); + add_opt(common_arg( + {"--k-shift"}, "N", + string_format("k-shift sampling (default: %d, 0 = disabled)", params.sparams.k_shift), + [](common_params & params, int value) { + params.sparams.k_shift = value; + } + ).set_sparam()); add_opt(common_arg( {"--top-k"}, "N", string_format("top-k sampling (default: %d, 0 = disabled)", params.sparams.top_k), diff --git a/common/common.cpp b/common/common.cpp index 19674af15..f90215358 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2096,6 +2096,7 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector); fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency()); + fprintf(stream, "k_shift: %d # default: 0\n", sparams.k_shift); fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); diff --git a/common/common.h b/common/common.h index 727f85baa..9008741f8 100644 --- a/common/common.h +++ b/common/common.h @@ -85,14 +85,15 @@ enum llama_example { enum common_sampler_type { COMMON_SAMPLER_TYPE_NONE = 0, COMMON_SAMPLER_TYPE_DRY = 1, - COMMON_SAMPLER_TYPE_TOP_K = 2, - COMMON_SAMPLER_TYPE_TOP_P = 3, - COMMON_SAMPLER_TYPE_MIN_P = 4, - //COMMON_SAMPLER_TYPE_TFS_Z = 5, - COMMON_SAMPLER_TYPE_TYPICAL_P = 6, - COMMON_SAMPLER_TYPE_TEMPERATURE = 7, - COMMON_SAMPLER_TYPE_XTC = 8, - COMMON_SAMPLER_TYPE_INFILL = 9, + COMMON_SAMPLER_TYPE_K_SHIFT = 2, + COMMON_SAMPLER_TYPE_TOP_K = 3, + COMMON_SAMPLER_TYPE_TOP_P = 4, + COMMON_SAMPLER_TYPE_MIN_P = 5, + //COMMON_SAMPLER_TYPE_TFS_Z = 6, + COMMON_SAMPLER_TYPE_TYPICAL_P = 7, + COMMON_SAMPLER_TYPE_TEMPERATURE = 8, + COMMON_SAMPLER_TYPE_XTC = 9, + COMMON_SAMPLER_TYPE_INFILL = 10, }; // dimensionality reduction methods, used by cvector-generator @@ -108,6 +109,7 @@ struct common_sampler_params { int32_t n_prev = 64; // number of previous tokens to remember int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t k_shift = 0; // 0 = disabled int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled float min_p = 0.05f; // 0.0 = disabled @@ -137,6 +139,7 @@ struct common_sampler_params { std::vector samplers = { COMMON_SAMPLER_TYPE_DRY, + COMMON_SAMPLER_TYPE_K_SHIFT, COMMON_SAMPLER_TYPE_TOP_K, COMMON_SAMPLER_TYPE_TYPICAL_P, COMMON_SAMPLER_TYPE_TOP_P, diff --git a/common/sampling.cpp b/common/sampling.cpp index 7922fde47..abf088452 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -131,11 +131,11 @@ std::string common_sampler_params::print() const { snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" - "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n" + "\tk_shift = %d, top_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", penalty_last_n, penalty_repeat, penalty_freq, penalty_present, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, - top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp, + k_shift, top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp, mirostat, mirostat_eta, mirostat_tau); return std::string(result); @@ -187,6 +187,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); } break; + case COMMON_SAMPLER_TYPE_K_SHIFT: + llama_sampler_chain_add(result->chain, llama_sampler_init_k_shift (params.k_shift)); + break; case COMMON_SAMPLER_TYPE_TOP_K: llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); break; @@ -369,6 +372,7 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { switch (cnstr) { case COMMON_SAMPLER_TYPE_DRY: return 'd'; + case COMMON_SAMPLER_TYPE_K_SHIFT: return 's'; case COMMON_SAMPLER_TYPE_TOP_K: return 'k'; case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y'; case COMMON_SAMPLER_TYPE_TOP_P: return 'p'; @@ -383,6 +387,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { switch (cnstr) { case COMMON_SAMPLER_TYPE_DRY: return "dry"; + case COMMON_SAMPLER_TYPE_K_SHIFT: return "k_shift"; case COMMON_SAMPLER_TYPE_TOP_K: return "top_k"; case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; case COMMON_SAMPLER_TYPE_TOP_P: return "top_p"; @@ -398,6 +403,7 @@ std::vector common_sampler_types_from_names(const std::vect std::unordered_map sampler_canonical_name_map { { "dry", COMMON_SAMPLER_TYPE_DRY }, { "top_k", COMMON_SAMPLER_TYPE_TOP_K }, + { "k_shift", COMMON_SAMPLER_TYPE_K_SHIFT }, { "top_p", COMMON_SAMPLER_TYPE_TOP_P }, { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "min_p", COMMON_SAMPLER_TYPE_MIN_P }, @@ -410,6 +416,7 @@ std::vector common_sampler_types_from_names(const std::vect // make it ready for both system names and input names std::unordered_map sampler_alt_name_map { { "top-k", COMMON_SAMPLER_TYPE_TOP_K }, + { "k-shift", COMMON_SAMPLER_TYPE_K_SHIFT }, { "top-p", COMMON_SAMPLER_TYPE_TOP_P }, { "nucleus", COMMON_SAMPLER_TYPE_TOP_P }, { "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P }, @@ -443,6 +450,7 @@ std::vector common_sampler_types_from_names(const std::vect std::vector common_sampler_types_from_chars(const std::string & chars) { std::unordered_map sampler_name_map = { { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_K_SHIFT), COMMON_SAMPLER_TYPE_K_SHIFT }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P }, diff --git a/examples/main/README.md b/examples/main/README.md index 145216938..ca95d4ee3 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -211,6 +211,14 @@ DRY sampling provides more nuanced control over text generation, particularly fo Example usage: `--dry-multiplier 0.8 --dry-base 1.75 --dry-allowed-length 2 --dry-penalty-last-n -1 --dry-sequence-breaker "—" --dry-sequence-breaker "##"` +### K-Shift Sampling + +- `--k-shift N`: Shift the first token selection by cutting out N tokens from the top once (default: 0). + +K-Shift is a sampling method that guides models away from the most obvious output, eliciting reasoning and analysis. It cuts out k top tokens once at the beginning of inference, making sure that the dialog will start from a less obvious path without guiding the model too much. The method was mentoned in a paper [Chain-of-Thought Reasoning without Prompting](https://arxiv.org/pdf/2402.10200) as a simple trick to guiding a model towards reasoning. In practice, K-Shift can improve the quality of reasoning, help bypass bias or censorship in certain cases, and may also be used as a diagnostics tool. K-Shift is intended to be used with greedy sampling (`--k-shift 10 --top-k 1`), but can help with creative writing too - albeit, not as much as XTC. The default value is 0. + +Example usage: `--k-shift 10` + ### Top-K Sampling - `--top-k N`: Limit the next token selection to the K most probable tokens (default: 40). diff --git a/examples/server/public/index-new.html b/examples/server/public/index-new.html index 8bfa380e5..06f4ef7b4 100644 --- a/examples/server/public/index-new.html +++ b/examples/server/public/index-new.html @@ -44,7 +44,8 @@ dry_base: 1.75, // 0.0 = disabled dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) - top_k: 0, // <= 0 to use vocab size + k_shift: 0, // <= 0 to use vocab size + top_k: 0, // 0 = disabled top_p: 1.0, // 1.0 = disabled min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4 xtc_probability: 0.0, // 0 = disabled; @@ -834,6 +835,7 @@ return html`
Further Options
+ ${IntField({ label: "K-Shift", title: "Cuts out first k tokens once at the start of sampling. Intended to use with greedy sampling.", max: 100, min: 0, step: 1, name: "k_shift", value: params.value.k_shift })} ${IntField({ label: "Top-K", title: "Limits the selection of the next token to the K most probable tokens. 1 means no randomness = greedy sampling. If set to 0, it means the entire vocabulary size is considered.", max: 100, min: 0, step: 1, name: "top_k", value: params.value.top_k })} ${IntField({ label: "Penalize Last N", title: "The last n tokens that are taken into account to penalise repetitions. A value of 0 means that this function is deactivated and -1 means that the entire size of the context is taken into account.", max: 2048, min: 0, step: 16, name: "repeat_last_n", value: params.value.repeat_last_n })} ${FloatField({ label: "Presence Penalty", title: "A penalty that is applied if certain tokens appear repeatedly in the generated text. A higher value leads to fewer repetitions.", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })} diff --git a/examples/server/public/index.html b/examples/server/public/index.html index a95f5c6df..3e486bf9d 100644 --- a/examples/server/public/index.html +++ b/examples/server/public/index.html @@ -308,6 +308,7 @@ dry_base: 1.75, // 0.0 = disabled dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) + k_shift: 0, // 0 = disabled top_k: 40, // <= 0 to use vocab size top_p: 0.95, // 1.0 = disabled min_p: 0.05, // 0 = disabled @@ -1007,6 +1008,7 @@ ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })} ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })} ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })} + ${IntField({ label: "K-shift", max: 100, min: -1, name: "k_shift", value: params.value.k_shift })} ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })} ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })} ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })} diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1c7f0fd1d..6c21fd36a 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -801,6 +801,7 @@ struct server_context { slot.params.cache_prompt = json_value(data, "cache_prompt", false); slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict)); slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent); + slot.sparams.k_shift = json_value(data, "k_shift", default_sparams.k_shift); slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k); slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); @@ -1140,6 +1141,7 @@ struct server_context { {"temperature", slot.sparams.temp}, {"dynatemp_range", slot.sparams.dynatemp_range}, {"dynatemp_exponent", slot.sparams.dynatemp_exponent}, + {"k_shift", slot.sparams.k_shift}, {"top_k", slot.sparams.top_k}, {"top_p", slot.sparams.top_p}, {"min_p", slot.sparams.min_p}, diff --git a/include/llama.h b/include/llama.h index ccb48f73c..74b4681e0 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1097,6 +1097,9 @@ extern "C" { /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); + + LLAMA_API struct llama_sampler * llama_sampler_init_k_shift (int32_t k); + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index fd8ca8a9e..03ef703e9 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -188,6 +188,17 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) cur_p->size = k; } +static void llama_sampler_top_shift_impl(llama_token_data_array * cur_p, int k) { + // sort before shifting + std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + + // shift to a token #[k] + cur_p->data += k; + cur_p->size -= k; +} + static uint32_t get_rng_seed(uint32_t seed) { if (seed == LLAMA_DEFAULT_SEED) { // use system clock if std::random_device is not a true RNG @@ -1082,6 +1093,64 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, }; } +// k-shift + +struct llama_sampler_k_shift { + const int32_t k; + bool k_set; +}; + +static const char * llama_sampler_k_shift_name(const struct llama_sampler * /*smpl*/) { + return "k-shift"; +} + +static void llama_sampler_k_shift_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_k_shift *) smpl->ctx; + + if (ctx->k_set == true + || ctx->k <= 0 + || ctx->k >= (int) cur_p->size) { + return; + } + + llama_sampler_top_shift_impl(cur_p, ctx->k); + ctx->k_set = true; +} + +static struct llama_sampler * llama_sampler_k_shift_clone(const struct llama_sampler * smpl) { + auto * ctx = (const llama_sampler_k_shift *) smpl->ctx; + + return llama_sampler_init_k_shift(ctx->k); +} + +static void llama_sampler_k_shift_free(struct llama_sampler * smpl) { + delete (llama_sampler_k_shift *) smpl->ctx; +} + +static void llama_sampler_k_shift_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_k_shift *) smpl->ctx; + ctx->k_set = false; +} + +static struct llama_sampler_i llama_sampler_k_shift_i = { + /* .name = */ llama_sampler_k_shift_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_k_shift_apply, + /* .reset = */ llama_sampler_k_shift_reset, + /* .clone = */ llama_sampler_k_shift_clone, + /* .free = */ llama_sampler_k_shift_free, +}; + +struct llama_sampler * llama_sampler_init_k_shift(int32_t k) { + return new llama_sampler { + /* .iface = */ &llama_sampler_k_shift_i, + /* .ctx = */ new llama_sampler_k_shift { + /* .k = */ k, + /* .k_set = */ false, + }, + }; +} + // mirostat struct llama_sampler_mirostat { diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index be370044d..fbaf5528d 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -83,6 +83,17 @@ static void test_temp_ext(const std::vector & probs, const std::vector & probs, const std::vector & probs_expected, int k) { + sampler_tester tester(probs, probs_expected); + + DUMP(&tester.cur_p); + tester.apply(llama_sampler_init_k_shift(k)); + tester.apply(llama_sampler_init_dist (0)); + DUMP(&tester.cur_p); + + tester.check(); +} + static void test_top_k(const std::vector & probs, const std::vector & probs_expected, int k) { sampler_tester tester(probs, probs_expected); @@ -288,11 +299,13 @@ static void test_perf() { data.emplace_back(llama_token_data{i, logit, 0.0f}); } - BENCH(llama_sampler_init_top_k (40), data, 32); - BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32); - BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32); - BENCH(llama_sampler_init_typical(0.5f, 1), data, 32); - BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32); + + BENCH(llama_sampler_init_k_shift (10), data, 32); + BENCH(llama_sampler_init_top_k (40), data, 32); + BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32); + BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32); + BENCH(llama_sampler_init_typical (0.5f, 1), data, 32); + BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32); } int main(void) { @@ -304,6 +317,12 @@ int main(void) { test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f); test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f); + test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4); + test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 3); + test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.66666f, 0.33333f}, 2); + test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.5f, 0.33333f, 0.16666f}, 1); + test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0); + test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1); test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3); test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);