From ddc3c2208acf0fb5a05f28205f1291486f922822 Mon Sep 17 00:00:00 2001 From: VJHack Date: Thu, 9 Jan 2025 23:04:28 -0600 Subject: [PATCH 01/14] initial sampling changes: --- common/common.h | 3 +++ common/sampling.cpp | 27 +++++++++++++++++++-------- src/llama-sampling.cpp | 42 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 8 deletions(-) diff --git a/common/common.h b/common/common.h index 0d452cf0f..3671751e9 100644 --- a/common/common.h +++ b/common/common.h @@ -95,6 +95,7 @@ enum common_sampler_type { COMMON_SAMPLER_TYPE_XTC = 8, COMMON_SAMPLER_TYPE_INFILL = 9, COMMON_SAMPLER_TYPE_PENALTIES = 10, + COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11 }; // dimensionality reduction methods, used by cvector-generator @@ -128,6 +129,7 @@ struct common_params_sampling { int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + int32_t top_n_sigma = 2; float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate bool ignore_eos = false; @@ -146,6 +148,7 @@ struct common_params_sampling { COMMON_SAMPLER_TYPE_MIN_P, COMMON_SAMPLER_TYPE_XTC, COMMON_SAMPLER_TYPE_TEMPERATURE, + COMMON_SAMPLER_TYPE_TOP_N_SIGMA, }; std::string grammar; // optional BNF-like grammar to constrain sampling diff --git a/common/sampling.cpp b/common/sampling.cpp index e83a971c7..15b08fe70 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -176,28 +176,32 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } break; case COMMON_SAMPLER_TYPE_TOP_K: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); break; case COMMON_SAMPLER_TYPE_TOP_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); + llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_MIN_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); + llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_XTC: - llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); break; case COMMON_SAMPLER_TYPE_TYPICAL_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); + llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_TEMPERATURE: - llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); break; case COMMON_SAMPLER_TYPE_INFILL: - llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model)); + llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model)); break; case COMMON_SAMPLER_TYPE_PENALTIES: - llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + break; + case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: + // llama_sampler_chain_add(result->chain, ) + llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma)) break; default: GGML_ASSERT(false && "unknown sampler type"); @@ -407,6 +411,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_XTC: return 'x'; case COMMON_SAMPLER_TYPE_INFILL: return 'i'; case COMMON_SAMPLER_TYPE_PENALTIES: return 'e'; + case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's'; default : return '?'; } } @@ -422,6 +427,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_XTC: return "xtc"; case COMMON_SAMPLER_TYPE_INFILL: return "infill"; case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties"; + case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma"; default : return ""; } } @@ -437,6 +443,7 @@ std::vector common_sampler_types_from_names(const std::vect { "xtc", COMMON_SAMPLER_TYPE_XTC }, { "infill", COMMON_SAMPLER_TYPE_INFILL }, { "penalties", COMMON_SAMPLER_TYPE_PENALTIES }, + { "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, }; // since samplers names are written multiple ways @@ -451,6 +458,9 @@ std::vector common_sampler_types_from_names(const std::vect { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "min-p", COMMON_SAMPLER_TYPE_MIN_P }, { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE }, + { "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, + { "top-nsigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, + { "top_nsigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, }; std::vector samplers; @@ -484,6 +494,7 @@ std::vector common_sampler_types_from_chars(const std::stri { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA} }; std::vector samplers; diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index ef5a576cc..d4e5e9be7 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1645,6 +1645,48 @@ struct llama_sampler * llama_sampler_init_penalties( }; } +// top-n-sigma + +struct llama_sampler_top_n_sigma { + const int32_t n; +}; + +static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) { + return "top-n-sigma"; +} + +static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; + llama_sampler_top_n_sigma_impl(cur_p, ctx->n); +} + +// static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) { +// const auto * ctx = (const llama_sampler_top_k *) smpl->ctx; +// return llama_sampler_init_top_k(ctx->k); +// } + +// static void llama_sampler_top_k_free(struct llama_sampler * smpl) { +// delete (llama_sampler_top_k *) smpl->ctx; +// } + +// static struct llama_sampler_i llama_sampler_top_k_i = { +// /* .name = */ llama_sampler_top_k_name, +// /* .accept = */ nullptr, +// /* .apply = */ llama_sampler_top_k_apply, +// /* .reset = */ nullptr, +// /* .clone = */ llama_sampler_top_k_clone, +// /* .free = */ llama_sampler_top_k_free, +// }; + +// struct llama_sampler * llama_sampler_init_top_k(int32_t k) { +// return new llama_sampler { +// /* .iface = */ &llama_sampler_top_k_i, +// /* .ctx = */ new llama_sampler_top_k { +// /* .k = */ k, +// }, +// }; +// } + // DRY struct llama_sampler_dry { From da038d8715c68fb02f62059dcbf52882b501ad39 Mon Sep 17 00:00:00 2001 From: VJHack Date: Mon, 13 Jan 2025 14:46:12 -0600 Subject: [PATCH 02/14] completed top nsigma sampler implementation --- common/arg.cpp | 7 +++ common/common.h | 4 +- common/sampling.cpp | 98 ++++++++++++++++++++---------------------- include/llama.h | 3 ++ src/llama-sampling.cpp | 79 +++++++++++++++++++++++----------- 5 files changed, 112 insertions(+), 79 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 27886b84e..6c6be6ef7 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -899,6 +899,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sampling.min_p = std::stof(value); } ).set_sparam()); + add_opt(common_arg( + {"--top-nsigma"}, "N", + string_format("top-n-sigma sampling (default: %d, -1 = disabled)", params.sampling.top_n_sigma), + [](common_params & params, const std::string & value) { + params.sampling.top_n_sigma = std::stof(value); + } + ).set_sparam()); add_opt(common_arg( {"--xtc-probability"}, "N", string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), diff --git a/common/common.h b/common/common.h index 3671751e9..50f1def67 100644 --- a/common/common.h +++ b/common/common.h @@ -95,7 +95,6 @@ enum common_sampler_type { COMMON_SAMPLER_TYPE_XTC = 8, COMMON_SAMPLER_TYPE_INFILL = 9, COMMON_SAMPLER_TYPE_PENALTIES = 10, - COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11 }; // dimensionality reduction methods, used by cvector-generator @@ -129,7 +128,7 @@ struct common_params_sampling { int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - int32_t top_n_sigma = 2; + int32_t top_n_sigma = -1; // -1 = disabled float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate bool ignore_eos = false; @@ -148,7 +147,6 @@ struct common_params_sampling { COMMON_SAMPLER_TYPE_MIN_P, COMMON_SAMPLER_TYPE_XTC, COMMON_SAMPLER_TYPE_TEMPERATURE, - COMMON_SAMPLER_TYPE_TOP_N_SIGMA, }; std::string grammar; // optional BNF-like grammar to constrain sampling diff --git a/common/sampling.cpp b/common/sampling.cpp index 15b08fe70..9d58c1680 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -131,11 +131,11 @@ std::string common_params_sampling::print() const { snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" - "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n" - "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", + "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %d, temp = %.3f\n" + "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f,", penalty_last_n, penalty_repeat, penalty_freq, penalty_present, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, - top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp, + top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp, mirostat, mirostat_eta, mirostat_tau); return std::string(result); @@ -162,49 +162,50 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co params.logit_bias.data())); if (params.mirostat == 0) { - for (const auto & cnstr : params.samplers) { - switch (cnstr) { - case COMMON_SAMPLER_TYPE_DRY: - { - std::vector c_breakers; - c_breakers.reserve(params.dry_sequence_breakers.size()); - for (const auto & str : params.dry_sequence_breakers) { - c_breakers.push_back(str.c_str()); - } + if(params.top_n_sigma >= 0) { + llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); + llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma)); + } else { + for (const auto & cnstr : params.samplers) { + switch (cnstr) { + case COMMON_SAMPLER_TYPE_DRY: + { + std::vector c_breakers; + c_breakers.reserve(params.dry_sequence_breakers.size()); + for (const auto & str : params.dry_sequence_breakers) { + c_breakers.push_back(str.c_str()); + } - llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); - } - break; - case COMMON_SAMPLER_TYPE_TOP_K: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); - break; - case COMMON_SAMPLER_TYPE_TOP_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_MIN_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_XTC: - llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); - break; - case COMMON_SAMPLER_TYPE_TYPICAL_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_TEMPERATURE: - llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); - break; - case COMMON_SAMPLER_TYPE_INFILL: - llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model)); - break; - case COMMON_SAMPLER_TYPE_PENALTIES: - llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); - break; - case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: - // llama_sampler_chain_add(result->chain, ) - llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma)) - break; - default: - GGML_ASSERT(false && "unknown sampler type"); + llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + } + break; + case COMMON_SAMPLER_TYPE_TOP_K: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + break; + case COMMON_SAMPLER_TYPE_TOP_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_MIN_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_XTC: + llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + break; + case COMMON_SAMPLER_TYPE_TYPICAL_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_TEMPERATURE: + llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + break; + case COMMON_SAMPLER_TYPE_INFILL: + llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model)); + break; + case COMMON_SAMPLER_TYPE_PENALTIES: + llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + break; + default: + GGML_ASSERT(false && "unknown sampler type"); + } } } llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); @@ -411,7 +412,6 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_XTC: return 'x'; case COMMON_SAMPLER_TYPE_INFILL: return 'i'; case COMMON_SAMPLER_TYPE_PENALTIES: return 'e'; - case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's'; default : return '?'; } } @@ -427,7 +427,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_XTC: return "xtc"; case COMMON_SAMPLER_TYPE_INFILL: return "infill"; case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties"; - case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma"; default : return ""; } } @@ -443,7 +442,6 @@ std::vector common_sampler_types_from_names(const std::vect { "xtc", COMMON_SAMPLER_TYPE_XTC }, { "infill", COMMON_SAMPLER_TYPE_INFILL }, { "penalties", COMMON_SAMPLER_TYPE_PENALTIES }, - { "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, }; // since samplers names are written multiple ways @@ -458,9 +456,6 @@ std::vector common_sampler_types_from_names(const std::vect { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "min-p", COMMON_SAMPLER_TYPE_MIN_P }, { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE }, - { "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, - { "top-nsigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, - { "top_nsigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, }; std::vector samplers; @@ -494,7 +489,6 @@ std::vector common_sampler_types_from_chars(const std::stri { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES }, - { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA} }; std::vector samplers; diff --git a/include/llama.h b/include/llama.h index 0295a51fb..7100d1ab0 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1133,6 +1133,9 @@ extern "C" { /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); + /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641 + LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(int32_t n); + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index d4e5e9be7..1eb7df950 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -301,6 +301,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) cur_p->size = k; } + static uint32_t get_rng_seed(uint32_t seed) { if (seed == LLAMA_DEFAULT_SEED) { // use system clock if std::random_device is not a true RNG @@ -1657,35 +1658,65 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; - llama_sampler_top_n_sigma_impl(cur_p, ctx->n); + // 1. Find max logit: M + // 2. Find standard deviation of logits: sig + // 3. Create a mask where m[i] = 1 if ith logit >= M - n (sig), else m[i] = 0 + // 4. Apply mask: ith logit itself if m[i]==1, else ith logit = -inf + // 5. p = softmax(l) + + // find max logit and calculate mean + int32_t max = cur_p->data[0].logit; + int32_t logits_sum = 0; + for (size_t i = 0; i < cur_p->size; ++i) { + if(cur_p->data[i].logit > max){ + max = cur_p->data[i].logit; + } + logits_sum += cur_p->data[i].logit; + } + int32_t mean = logits_sum/cur_p->size; + + // calculate standard deviation + int32_t acc = 0; + for(size_t i = 0; i < cur_p->size; ++i){ + acc += (cur_p->data[i].logit - mean) * (cur_p->data[i].logit - mean); + } + int32_t std = sqrt(acc/cur_p->size); + + //apply mask + for(size_t i = 0; i < cur_p->size; ++i){ + if(cur_p->data[i].logit < max - (ctx->n * std)) { + cur_p->data[i].logit = -INFINITY; + } + } + llama_sampler_softmax_impl(cur_p); } -// static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) { -// const auto * ctx = (const llama_sampler_top_k *) smpl->ctx; -// return llama_sampler_init_top_k(ctx->k); -// } +static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl){ + const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx; + return llama_sampler_init_top_n_sigma(ctx->n); +} -// static void llama_sampler_top_k_free(struct llama_sampler * smpl) { -// delete (llama_sampler_top_k *) smpl->ctx; -// } +static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) { + delete (llama_sampler_top_n_sigma *) smpl->ctx; +} -// static struct llama_sampler_i llama_sampler_top_k_i = { -// /* .name = */ llama_sampler_top_k_name, -// /* .accept = */ nullptr, -// /* .apply = */ llama_sampler_top_k_apply, -// /* .reset = */ nullptr, -// /* .clone = */ llama_sampler_top_k_clone, -// /* .free = */ llama_sampler_top_k_free, -// }; +static struct llama_sampler_i llama_sampler_top_n_sigma_i = { + /* .name = */ llama_sampler_top_n_sigma_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_n_sigma_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_n_sigma_clone, + /* .free = */ llama_sampler_top_n_sigma_free, +}; -// struct llama_sampler * llama_sampler_init_top_k(int32_t k) { -// return new llama_sampler { -// /* .iface = */ &llama_sampler_top_k_i, -// /* .ctx = */ new llama_sampler_top_k { -// /* .k = */ k, -// }, -// }; -// } +struct llama_sampler * llama_sampler_init_top_n_sigma(int32_t n) { + return new llama_sampler { + /* .iface = */ &llama_sampler_top_n_sigma_i, + /* .ctx = */ new llama_sampler_top_n_sigma { + /* .n = */ n, + }, + }; +} // DRY From bee4c7c9fa0e44a70ef8802e2d0f86a29b8498d9 Mon Sep 17 00:00:00 2001 From: VJHack Date: Mon, 13 Jan 2025 15:12:50 -0600 Subject: [PATCH 03/14] apply parameter to only llama-cli --- common/arg.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/arg.cpp b/common/arg.cpp index 6c6be6ef7..c62e44168 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -905,7 +905,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.sampling.top_n_sigma = std::stof(value); } - ).set_sparam()); + ).set_examples({LLAMA_EXAMPLE_MAIN}).set_sparam()); add_opt(common_arg( {"--xtc-probability"}, "N", string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), From 8fb681bf9ae94eee631f87abdb4f6175d8951ce9 Mon Sep 17 00:00:00 2001 From: VJHack Date: Mon, 13 Jan 2025 16:17:39 -0600 Subject: [PATCH 04/14] updated readme --- examples/main/README.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/main/README.md b/examples/main/README.md index 17d80a622..9591b1b0a 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -265,7 +265,15 @@ Being experimental and unique, XTC is disabled by default. The recommended combi Example usage: `--xtc-probability 0.5 --xtc-threshold 0.1` -### Logit Bias +### Top-nσ Sampling + +- `--top-nsigma N`: Limit the next token selection to a subset of tokens with pre-softmax logits that are within n * σ less than the max logit (default: -1, -1 = disabled). + +Top-nσ sampling is a text generation method that selects tokens based on a statistical threshold in pre-softmax logits. It works by only sampling from tokens with logits that are within n * σ of the maximum logit. This method helps maintain a stable sampling space regardless of temperature scaling, allowing it to perform well on reasoning tasks even in high temperatures. Without complex probability manipulation, it efficiently filters tokens directly on the pre-softmax logits. A higher value for top-nsigma (e.g., 5) will take more noisy tokens into consideration, while a lower value (e.g., 1) will focous on the more informative region of the sampling space. + +Example usage: `--top-nsigma 1` + +### Logit Bias - `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion. From 54ef105c85b1220908fa40c2b59f768a0392139b Mon Sep 17 00:00:00 2001 From: VJHack Date: Mon, 13 Jan 2025 17:48:35 -0600 Subject: [PATCH 05/14] added tests and fixed nsigma impl --- src/llama-sampling.cpp | 22 +++++++++------------- tests/test-sampling.cpp | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 1eb7df950..e91401d75 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1655,36 +1655,32 @@ struct llama_sampler_top_n_sigma { static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) { return "top-n-sigma"; } +#include static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; - // 1. Find max logit: M - // 2. Find standard deviation of logits: sig - // 3. Create a mask where m[i] = 1 if ith logit >= M - n (sig), else m[i] = 0 - // 4. Apply mask: ith logit itself if m[i]==1, else ith logit = -inf - // 5. p = softmax(l) // find max logit and calculate mean - int32_t max = cur_p->data[0].logit; - int32_t logits_sum = 0; + float max = cur_p->data[0].logit; + float logits_sum = 0; for (size_t i = 0; i < cur_p->size; ++i) { if(cur_p->data[i].logit > max){ max = cur_p->data[i].logit; } logits_sum += cur_p->data[i].logit; } - int32_t mean = logits_sum/cur_p->size; + float mean = (float)logits_sum/cur_p->size; // calculate standard deviation - int32_t acc = 0; + float acc = 0; for(size_t i = 0; i < cur_p->size; ++i){ - acc += (cur_p->data[i].logit - mean) * (cur_p->data[i].logit - mean); + acc += pow(cur_p->data[i].logit - mean, 2); } - int32_t std = sqrt(acc/cur_p->size); - + float std = sqrt((float)acc/cur_p->size); + //apply mask for(size_t i = 0; i < cur_p->size; ++i){ - if(cur_p->data[i].logit < max - (ctx->n * std)) { + if(cur_p->data[i].logit < max - ((float)ctx->n * std)) { cur_p->data[i].logit = -INFINITY; } } diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index c0dcb4848..59bde4d41 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -182,6 +182,17 @@ static void test_dry( tester.check(); } +static void test_top_n_sigma(const std::vector & probs, const std::vector & probs_expected, int n) { + sampler_tester tester(probs, probs_expected); + + DUMP(&tester.cur_p); + tester.apply(llama_sampler_init_top_n_sigma(n)); + tester.apply(llama_sampler_init_dist (0)); + DUMP(&tester.cur_p); + + tester.check(); +} + static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p ) { sampler_tester tester(n_vocab); @@ -349,6 +360,14 @@ int main(void) { test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {}); test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {}); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3); + + // test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3); + // test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4); + // test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0); + test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f); test_sampler_queue(10000, "k", 1, 1.0f, 1.0f); test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f); From d905a9e9b7339a23f5181dce4f839169e1ecfda2 Mon Sep 17 00:00:00 2001 From: VJHack Date: Mon, 13 Jan 2025 17:51:40 -0600 Subject: [PATCH 06/14] cleaned up pr --- common/sampling.cpp | 2 +- src/llama-sampling.cpp | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 9d58c1680..1c56999a3 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -132,7 +132,7 @@ std::string common_params_sampling::print() const { "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %d, temp = %.3f\n" - "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f,", + "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", penalty_last_n, penalty_repeat, penalty_freq, penalty_present, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp, diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index e91401d75..eeba3aa7e 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -301,7 +301,6 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) cur_p->size = k; } - static uint32_t get_rng_seed(uint32_t seed) { if (seed == LLAMA_DEFAULT_SEED) { // use system clock if std::random_device is not a true RNG @@ -1677,7 +1676,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t acc += pow(cur_p->data[i].logit - mean, 2); } float std = sqrt((float)acc/cur_p->size); - + //apply mask for(size_t i = 0; i < cur_p->size; ++i){ if(cur_p->data[i].logit < max - ((float)ctx->n * std)) { From a590dcb7f6cd8460284723ed30fa721735c7cdf0 Mon Sep 17 00:00:00 2001 From: VJHack Date: Mon, 13 Jan 2025 19:05:15 -0600 Subject: [PATCH 07/14] format --- .pre-commit-config.yaml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 91d791628..97e78b9a5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,8 +9,3 @@ repos: - id: end-of-file-fixer - id: check-yaml - id: check-added-large-files -- repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 - hooks: - - id: flake8 - additional_dependencies: [flake8-no-print] From 0f7501c913bcf57a25956e91a06591a2ac1158bd Mon Sep 17 00:00:00 2001 From: VJHack Date: Mon, 13 Jan 2025 19:05:37 -0600 Subject: [PATCH 08/14] format --- .pre-commit-config.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 97e78b9a5..91d791628 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,3 +9,8 @@ repos: - id: end-of-file-fixer - id: check-yaml - id: check-added-large-files +- repo: https://github.com/PyCQA/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + additional_dependencies: [flake8-no-print] From b29deb83cd6cf1a19bce3610107868c43a9a4eda Mon Sep 17 00:00:00 2001 From: VJHack Date: Mon, 13 Jan 2025 19:08:30 -0600 Subject: [PATCH 09/14] format --- examples/main/README.md | 2 +- src/llama-sampling.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/main/README.md b/examples/main/README.md index 9591b1b0a..f2de4e81b 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -273,7 +273,7 @@ Top-nσ sampling is a text generation method that selects tokens based on a stat Example usage: `--top-nsigma 1` -### Logit Bias +### Logit Bias - `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion. diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 11035e72e..9938f0f3d 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1672,7 +1672,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t logits_sum += cur_p->data[i].logit; } float mean = (float)logits_sum/cur_p->size; - + // calculate standard deviation float acc = 0; for(size_t i = 0; i < cur_p->size; ++i){ From f08e6f5bdcb1a01f17be1aec32a4f1b95cf28eef Mon Sep 17 00:00:00 2001 From: VJHack Date: Mon, 13 Jan 2025 20:32:08 -0600 Subject: [PATCH 10/14] removed commented tests --- tests/test-sampling.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 59bde4d41..d2459f91d 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -364,10 +364,6 @@ int main(void) { test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0); test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3); - // test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3); - // test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4); - // test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0); - test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f); test_sampler_queue(10000, "k", 1, 1.0f, 1.0f); test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f); From 6664d4709fe66910ac82b07879ef67189fd31bbb Mon Sep 17 00:00:00 2001 From: VJHack Date: Mon, 13 Jan 2025 22:21:38 -0600 Subject: [PATCH 11/14] cleanup pr and remove explicit floats --- src/llama-sampling.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 9938f0f3d..9826bdd09 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1657,7 +1657,6 @@ struct llama_sampler_top_n_sigma { static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) { return "top-n-sigma"; } -#include static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; @@ -1671,18 +1670,18 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t } logits_sum += cur_p->data[i].logit; } - float mean = (float)logits_sum/cur_p->size; + float mean = logits_sum/cur_p->size; // calculate standard deviation float acc = 0; for(size_t i = 0; i < cur_p->size; ++i){ acc += pow(cur_p->data[i].logit - mean, 2); } - float std = sqrt((float)acc/cur_p->size); + float std = sqrt(acc/cur_p->size); //apply mask for(size_t i = 0; i < cur_p->size; ++i){ - if(cur_p->data[i].logit < max - ((float)ctx->n * std)) { + if(cur_p->data[i].logit < max - ctx->n * std) { cur_p->data[i].logit = -INFINITY; } } From c6123e69b00012d3ed95f8d19ecf10e8115de07f Mon Sep 17 00:00:00 2001 From: VJHack Date: Fri, 17 Jan 2025 01:17:40 -0600 Subject: [PATCH 12/14] added top-k sampler to improve performance --- common/sampling.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index fdc493f04..ddface5ed 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -168,8 +168,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co if (params.mirostat == 0) { if(params.top_n_sigma >= 0) { - llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma)); + llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp)); + llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma)); } else { for (const auto & cnstr : params.samplers) { switch (cnstr) { From 6c1ca58f071a53e545eeddeac4cc78e4d34689a3 Mon Sep 17 00:00:00 2001 From: VJHack Date: Sun, 19 Jan 2025 22:40:54 -0600 Subject: [PATCH 13/14] changed sigma to float --- common/common.h | 2 +- common/sampling.cpp | 2 +- include/llama.h | 2 +- src/llama-sampling.cpp | 6 +++--- tests/test-sampling.cpp | 6 +++--- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/common/common.h b/common/common.h index 1cce234df..9ea5d4d8b 100644 --- a/common/common.h +++ b/common/common.h @@ -134,7 +134,7 @@ struct common_params_sampling { int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - int32_t top_n_sigma = -1; // -1 = disabled + float top_n_sigma = -1.00f;// -1.0 = disabled float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate bool ignore_eos = false; diff --git a/common/sampling.cpp b/common/sampling.cpp index ddface5ed..0bd682774 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -134,7 +134,7 @@ std::string common_params_sampling::print() const { snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" - "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %d, temp = %.3f\n" + "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", penalty_last_n, penalty_repeat, penalty_freq, penalty_present, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, diff --git a/include/llama.h b/include/llama.h index acc177231..76dff06c4 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1164,7 +1164,7 @@ extern "C" { LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641 - LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(int32_t n); + LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n); /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 9826bdd09..876e51d5c 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1651,7 +1651,7 @@ struct llama_sampler * llama_sampler_init_penalties( // top-n-sigma struct llama_sampler_top_n_sigma { - const int32_t n; + const float n; }; static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) { @@ -1681,7 +1681,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t //apply mask for(size_t i = 0; i < cur_p->size; ++i){ - if(cur_p->data[i].logit < max - ctx->n * std) { + if(cur_p->data[i].logit < max - (ctx->n * std)) { cur_p->data[i].logit = -INFINITY; } } @@ -1706,7 +1706,7 @@ static struct llama_sampler_i llama_sampler_top_n_sigma_i = { /* .free = */ llama_sampler_top_n_sigma_free, }; -struct llama_sampler * llama_sampler_init_top_n_sigma(int32_t n) { +struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { return new llama_sampler { /* .iface = */ &llama_sampler_top_n_sigma_i, /* .ctx = */ new llama_sampler_top_n_sigma { diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index d2459f91d..9b4f2341c 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -360,9 +360,9 @@ int main(void) { test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {}); test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {}); - test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1); - test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0); - test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.00f); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f); test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f); test_sampler_queue(10000, "k", 1, 1.0f, 1.0f); From a52e023969d317fffc70fc9410292b6a2f942abb Mon Sep 17 00:00:00 2001 From: VJHack Date: Sun, 19 Jan 2025 22:58:42 -0600 Subject: [PATCH 14/14] fixed string format to float --- common/arg.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/arg.cpp b/common/arg.cpp index 5e2f9c1f6..27d5c576d 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -914,7 +914,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_sparam()); add_opt(common_arg( {"--top-nsigma"}, "N", - string_format("top-n-sigma sampling (default: %d, -1 = disabled)", params.sampling.top_n_sigma), + string_format("top-n-sigma sampling (default: %.1f, -1.0 = disabled)", params.sampling.top_n_sigma), [](common_params & params, const std::string & value) { params.sampling.top_n_sigma = std::stof(value); }