From 87384fb5577afc16bdfff72aeebc7dc603528765 Mon Sep 17 00:00:00 2001 From: MaggotHATE Date: Fri, 25 Oct 2024 23:22:16 +0500 Subject: [PATCH 1/5] K-Shift commit --- common/arg.cpp | 7 +++ common/common.cpp | 1 + common/common.h | 19 ++++---- common/sampling.cpp | 12 ++++- examples/main/README.md | 8 ++++ examples/server/public/index-new.html | 4 +- examples/server/public/index.html | 2 + examples/server/server.cpp | 2 + include/llama.h | 3 ++ src/llama-sampling.cpp | 67 +++++++++++++++++++++++++++ 10 files changed, 114 insertions(+), 11 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index e1e933934..58cab7020 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -922,6 +922,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sparams.temp = std::max(params.sparams.temp, 0.0f); } ).set_sparam()); + add_opt(common_arg( + {"--k-shift"}, "N", + string_format("k-shift sampling (default: %d, 0 = disabled)", params.sparams.k_shift), + [](common_params & params, int value) { + params.sparams.k_shift = value; + } + ).set_sparam()); add_opt(common_arg( {"--top-k"}, "N", string_format("top-k sampling (default: %d, 0 = disabled)", params.sparams.top_k), diff --git a/common/common.cpp b/common/common.cpp index ff8cc4076..490c089de 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2092,6 +2092,7 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z); fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency()); + fprintf(stream, "k_shift: %d # default: 0\n", sparams.k_shift); fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); diff --git a/common/common.h b/common/common.h index 18b2121ed..c7b9a3ee2 100644 --- a/common/common.h +++ b/common/common.h @@ 
-85,14 +85,15 @@ enum llama_example { enum common_sampler_type { COMMON_SAMPLER_TYPE_NONE = 0, COMMON_SAMPLER_TYPE_DRY = 1, - COMMON_SAMPLER_TYPE_TOP_K = 2, - COMMON_SAMPLER_TYPE_TOP_P = 3, - COMMON_SAMPLER_TYPE_MIN_P = 4, - COMMON_SAMPLER_TYPE_TFS_Z = 5, - COMMON_SAMPLER_TYPE_TYPICAL_P = 6, - COMMON_SAMPLER_TYPE_TEMPERATURE = 7, - COMMON_SAMPLER_TYPE_XTC = 8, - COMMON_SAMPLER_TYPE_INFILL = 9, + COMMON_SAMPLER_TYPE_K_SHIFT = 2, + COMMON_SAMPLER_TYPE_TOP_K = 3, + COMMON_SAMPLER_TYPE_TOP_P = 4, + COMMON_SAMPLER_TYPE_MIN_P = 5, + COMMON_SAMPLER_TYPE_TFS_Z = 6, + COMMON_SAMPLER_TYPE_TYPICAL_P = 7, + COMMON_SAMPLER_TYPE_TEMPERATURE = 8, + COMMON_SAMPLER_TYPE_XTC = 9, + COMMON_SAMPLER_TYPE_INFILL = 10, }; // dimensionality reduction methods, used by cvector-generator @@ -108,6 +109,7 @@ struct common_sampler_params { int32_t n_prev = 64; // number of previous tokens to remember int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t k_shift = 0; // 0 = disabled int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled float min_p = 0.05f; // 0.0 = disabled @@ -138,6 +140,7 @@ struct common_sampler_params { std::vector samplers = { COMMON_SAMPLER_TYPE_DRY, + COMMON_SAMPLER_TYPE_K_SHIFT, COMMON_SAMPLER_TYPE_TOP_K, COMMON_SAMPLER_TYPE_TFS_Z, COMMON_SAMPLER_TYPE_TYPICAL_P, diff --git a/common/sampling.cpp b/common/sampling.cpp index 48a9df8ba..82827a616 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -131,11 +131,11 @@ std::string common_sampler_params::print() const { snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" - "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, 
typical_p = %.3f, temp = %.3f\n" + "\tk_shift = %d, top_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", penalty_last_n, penalty_repeat, penalty_freq, penalty_present, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, - top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp, + k_shift, top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp, mirostat, mirostat_eta, mirostat_tau); return std::string(result); @@ -187,6 +187,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); } break; + case COMMON_SAMPLER_TYPE_K_SHIFT: + llama_sampler_chain_add(result->chain, llama_sampler_init_k_shift (params.k_shift)); + break; case COMMON_SAMPLER_TYPE_TOP_K: llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); break; @@ -372,6 +375,7 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { switch (cnstr) { case COMMON_SAMPLER_TYPE_DRY: return 'd'; + case COMMON_SAMPLER_TYPE_K_SHIFT: return 's'; case COMMON_SAMPLER_TYPE_TOP_K: return 'k'; case COMMON_SAMPLER_TYPE_TFS_Z: return 'f'; case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y'; @@ -387,6 +391,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { switch (cnstr) { case COMMON_SAMPLER_TYPE_DRY: return "dry"; + case COMMON_SAMPLER_TYPE_K_SHIFT: return "k_shift"; case COMMON_SAMPLER_TYPE_TOP_K: return "top_k"; case COMMON_SAMPLER_TYPE_TFS_Z: return "tfs_z"; case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; 
@@ -403,6 +408,7 @@ std::vector common_sampler_types_from_names(const std::vect std::unordered_map sampler_canonical_name_map { { "dry", COMMON_SAMPLER_TYPE_DRY }, { "top_k", COMMON_SAMPLER_TYPE_TOP_K }, + { "k_shift", COMMON_SAMPLER_TYPE_K_SHIFT }, { "top_p", COMMON_SAMPLER_TYPE_TOP_P }, { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "min_p", COMMON_SAMPLER_TYPE_MIN_P }, @@ -416,6 +422,7 @@ std::vector common_sampler_types_from_names(const std::vect // make it ready for both system names and input names std::unordered_map sampler_alt_name_map { { "top-k", COMMON_SAMPLER_TYPE_TOP_K }, + { "k-shift", COMMON_SAMPLER_TYPE_K_SHIFT }, { "top-p", COMMON_SAMPLER_TYPE_TOP_P }, { "nucleus", COMMON_SAMPLER_TYPE_TOP_P }, { "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P }, @@ -451,6 +458,7 @@ std::vector common_sampler_types_from_names(const std::vect std::vector common_sampler_types_from_chars(const std::string & chars) { std::unordered_map sampler_name_map = { { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_K_SHIFT), COMMON_SAMPLER_TYPE_K_SHIFT }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TFS_Z), COMMON_SAMPLER_TYPE_TFS_Z }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P }, diff --git a/examples/main/README.md b/examples/main/README.md index c7c823171..2c7d45e0c 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -211,6 +211,14 @@ DRY sampling provides more nuanced control over text generation, particularly fo Example usage: `--dry-multiplier 0.8 --dry-base 1.75 --dry-allowed-length 2 --dry-penalty-last-n -1 --dry-sequence-breaker "—" --dry-sequence-breaker "##"` +### K-Shift Sampling + +- `--k-shift N`: Shift the first token selection by cutting out N tokens from the top once (default: 0). 
+ +K-Shift is a sampling method that guides models away from the most obvious output, eliciting reasoning and analysis. It cuts out k top tokens once at the beginning of inference, making sure that the dialog will start from a less obvious path without guiding the model too much. The method was mentioned in a paper [Chain-of-Thought Reasoning without Prompting](https://arxiv.org/pdf/2402.10200) as a simple trick to guide a model towards reasoning. In practice, K-Shift can improve the quality of reasoning, help bypass bias or censorship in certain cases, and may also be used as a diagnostics tool. K-Shift is intended to be used with greedy sampling (`--k-shift 10 --top-k 1`), but can help with creative writing too - albeit, not as much as XTC. The default value is 0. + +Example usage: `--k-shift 10` + ### Top-K Sampling - `--top-k N`: Limit the next token selection to the K most probable tokens (default: 40). diff --git a/examples/server/public/index-new.html b/examples/server/public/index-new.html index cb3995abe..f7a38bfea 100644 --- a/examples/server/public/index-new.html +++ b/examples/server/public/index-new.html @@ -44,7 +44,8 @@ dry_base: 1.75, // 0.0 = disabled dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) - top_k: 0, // <= 0 to use vocab size + k_shift: 0, // 0 = disabled + top_k: 0, // <= 0 to use vocab size top_p: 1.0, // 1.0 = disabled min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4 xtc_probability: 0.0, // 0 = disabled; @@ -835,6 +836,7 @@ return html`
Further Options
+ ${IntField({ label: "K-Shift", title: "Cuts out first k tokens once at the start of sampling. Intended to use with greedy sampling.", max: 100, min: 0, step: 1, name: "k_shift", value: params.value.k_shift })} ${IntField({ label: "Top-K", title: "Limits the selection of the next token to the K most probable tokens. 1 means no randomness = greedy sampling. If set to 0, it means the entire vocabulary size is considered.", max: 100, min: 0, step: 1, name: "top_k", value: params.value.top_k })} ${IntField({ label: "Penalize Last N", title: "The last n tokens that are taken into account to penalise repetitions. A value of 0 means that this function is deactivated and -1 means that the entire size of the context is taken into account.", max: 2048, min: 0, step: 16, name: "repeat_last_n", value: params.value.repeat_last_n })} ${FloatField({ label: "Presence Penalty", title: "A penalty that is applied if certain tokens appear repeatedly in the generated text. A higher value leads to fewer repetitions.", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })} diff --git a/examples/server/public/index.html b/examples/server/public/index.html index 7f9b02bfb..5d391f11a 100644 --- a/examples/server/public/index.html +++ b/examples/server/public/index.html @@ -308,6 +308,7 @@ dry_base: 1.75, // 0.0 = disabled dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) + k_shift: 0, // 0 = disabled top_k: 40, // <= 0 to use vocab size top_p: 0.95, // 1.0 = disabled min_p: 0.05, // 0 = disabled @@ -1008,6 +1009,7 @@ ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })} ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })} 
${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })} + ${IntField({ label: "K-shift", max: 100, min: -1, name: "k_shift", value: params.value.k_shift })} ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })} ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })} ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })} diff --git a/examples/server/server.cpp b/examples/server/server.cpp index ff1d9b03c..bec852cf2 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -804,6 +804,7 @@ struct server_context { slot.params.cache_prompt = json_value(data, "cache_prompt", false); slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict)); slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent); + slot.sparams.k_shift = json_value(data, "k_shift", default_sparams.k_shift); slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k); slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); @@ -1144,6 +1145,7 @@ struct server_context { {"temperature", slot.sparams.temp}, {"dynatemp_range", slot.sparams.dynatemp_range}, {"dynatemp_exponent", slot.sparams.dynatemp_exponent}, + {"k_shift", slot.sparams.k_shift}, {"top_k", slot.sparams.top_k}, {"top_p", slot.sparams.top_p}, {"min_p", slot.sparams.min_p}, diff --git a/include/llama.h b/include/llama.h index b2d1e7d5a..5837d74bc 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1102,6 +1102,9 @@ extern "C" { /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); 
+ + LLAMA_API struct llama_sampler * llama_sampler_init_k_shift (int32_t k); + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 25536eb6c..45b6104ca 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -188,6 +188,17 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) cur_p->size = k; } +static void llama_sampler_top_shift_impl(llama_token_data_array * cur_p, int k) { + // sort before shifting + std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + + // shift to a token #[k] + cur_p->data += k; + cur_p->size -= k; +} + static uint32_t get_rng_seed(uint32_t seed) { if (seed == LLAMA_DEFAULT_SEED) { // use system clock if std::random_device is not a true RNG @@ -1177,6 +1188,62 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, }; } +// k-shift + +struct llama_sampler_k_shift { + const int32_t k; + bool k_set; +}; + +static const char * llama_sampler_k_shift_name(const struct llama_sampler * /*smpl*/) { + return "k-shift"; +} + +static void llama_sampler_k_shift_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_k_shift *) smpl->ctx; + + if (ctx->k <= 0 || ctx->k_set == true) { + return; + } + + llama_sampler_top_shift_impl(cur_p, ctx->k); + ctx->k_set = true; +} + +static struct 
llama_sampler * llama_sampler_k_shift_clone(const struct llama_sampler * smpl) { + auto * ctx = (const llama_sampler_k_shift *) smpl->ctx; + + return llama_sampler_init_k_shift(ctx->k); +} + +static void llama_sampler_k_shift_free(struct llama_sampler * smpl) { + delete (llama_sampler_k_shift *) smpl->ctx; +} + +static void llama_sampler_k_shift_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_k_shift *) smpl->ctx; + ctx->k_set = false; +} + +static struct llama_sampler_i llama_sampler_k_shift_i = { + /* .name = */ llama_sampler_k_shift_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_k_shift_apply, + /* .reset = */ llama_sampler_k_shift_reset, + /* .clone = */ llama_sampler_k_shift_clone, + /* .free = */ llama_sampler_k_shift_free, +}; + +struct llama_sampler * llama_sampler_init_k_shift(int32_t k) { + return new llama_sampler { + /* .iface = */ &llama_sampler_k_shift_i, + /* .ctx = */ new llama_sampler_k_shift { + /* .k = */ k, + /* .k_set = */ false, + }, + }; +} + // mirostat struct llama_sampler_mirostat { From 070f9546f6ddb15887501a9dcc15c1b1e9cb502d Mon Sep 17 00:00:00 2001 From: MaggotHATE Date: Fri, 25 Oct 2024 23:28:30 +0500 Subject: [PATCH 2/5] Fixed style --- common/common.h | 2 +- examples/server/server.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/common/common.h b/common/common.h index c7b9a3ee2..97e35fb4c 100644 --- a/common/common.h +++ b/common/common.h @@ -109,7 +109,7 @@ struct common_sampler_params { int32_t n_prev = 64; // number of previous tokens to remember int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. 
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens - int32_t k_shift = 0; // 0 = disabled + int32_t k_shift = 0; // 0 = disabled int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled float min_p = 0.05f; // 0.0 = disabled diff --git a/examples/server/server.cpp b/examples/server/server.cpp index bec852cf2..a8fa34b19 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -804,7 +804,7 @@ struct server_context { slot.params.cache_prompt = json_value(data, "cache_prompt", false); slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict)); slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent); - slot.sparams.k_shift = json_value(data, "k_shift", default_sparams.k_shift); + slot.sparams.k_shift = json_value(data, "k_shift", default_sparams.k_shift); slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k); slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); From 48b715da284fa273222e585f3d678eeca9f18e05 Mon Sep 17 00:00:00 2001 From: MaggotHATE Date: Sat, 26 Oct 2024 12:21:46 +0500 Subject: [PATCH 3/5] Fixes and tests --- src/llama-sampling.cpp | 4 +++- tests/test-sampling.cpp | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 45b6104ca..ba3701472 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1202,7 +1202,9 @@ static const char * llama_sampler_k_shift_name(const struct llama_sampler * /*sm static void llama_sampler_k_shift_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_k_shift *) smpl->ctx; - if (ctx->k <= 0 || ctx->k_set == true) { + if (ctx->k_set == true + || ctx->k <= 0 + || ctx->k >= (int) cur_p->size) { return; } diff --git 
a/tests/test-sampling.cpp b/tests/test-sampling.cpp index eb39661c3..22f74ac4c 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -83,6 +83,17 @@ static void test_temp_ext(const std::vector & probs, const std::vector & probs, const std::vector & probs_expected, int k) { + sampler_tester tester(probs, probs_expected); + + DUMP(&tester.cur_p); + tester.apply(llama_sampler_init_k_shift(k)); + tester.apply(llama_sampler_init_dist (0)); + DUMP(&tester.cur_p); + + tester.check(); +} + static void test_top_k(const std::vector & probs, const std::vector & probs_expected, int k) { sampler_tester tester(probs, probs_expected); @@ -299,6 +310,7 @@ static void test_perf() { data.emplace_back(llama_token_data{i, logit, 0.0f}); } + BENCH(llama_sampler_init_k_shift (10), data, 32); BENCH(llama_sampler_init_top_k (40), data, 32); BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32); BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32); @@ -316,6 +328,12 @@ int main(void) { test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f); test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f); + test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4); + test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 3); + test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.66666f, 0.33333f}, 2); + test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.5f, 0.33333f, 0.16666f}, 1); + test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0); + test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1); test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3); test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4); From 9ef8cb5a3e9e3ad55cf56b9bb619588cb5a20f9a Mon Sep 17 00:00:00 2001 From: MaggotHATE Date: Fri, 1 Nov 2024 14:15:05 +0500 Subject: [PATCH 4/5] Removed custom reset --- src/llama-sampling.cpp | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp 
index 2b0f907a8..73bec0b98 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1127,16 +1127,11 @@ static void llama_sampler_k_shift_free(struct llama_sampler * smpl) { delete (llama_sampler_k_shift *) smpl->ctx; } -static void llama_sampler_k_shift_reset(struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_k_shift *) smpl->ctx; - ctx->k_set = false; -} - static struct llama_sampler_i llama_sampler_k_shift_i = { /* .name = */ llama_sampler_k_shift_name, /* .accept = */ nullptr, /* .apply = */ llama_sampler_k_shift_apply, - /* .reset = */ llama_sampler_k_shift_reset, + /* .reset = */ nullptr, /* .clone = */ llama_sampler_k_shift_clone, /* .free = */ llama_sampler_k_shift_free, }; From f853c3eacf8845ce091081c86ac4c47a93f6d73c Mon Sep 17 00:00:00 2001 From: MaggotHATE Date: Fri, 1 Nov 2024 17:23:34 +0500 Subject: [PATCH 5/5] Revert back reset function --- src/llama-sampling.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 73bec0b98..2b0f907a8 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1127,11 +1127,16 @@ static void llama_sampler_k_shift_free(struct llama_sampler * smpl) { delete (llama_sampler_k_shift *) smpl->ctx; } +static void llama_sampler_k_shift_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_k_shift *) smpl->ctx; + ctx->k_set = false; +} + static struct llama_sampler_i llama_sampler_k_shift_i = { /* .name = */ llama_sampler_k_shift_name, /* .accept = */ nullptr, /* .apply = */ llama_sampler_k_shift_apply, - /* .reset = */ nullptr, + /* .reset = */ llama_sampler_k_shift_reset, /* .clone = */ llama_sampler_k_shift_clone, /* .free = */ llama_sampler_k_shift_free, };