diff --git a/common/common.cpp b/common/common.cpp index 1dcc235ea..b148e3a78 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -280,6 +280,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.yarn_beta_slow = std::stof(argv[i]); } else if (arg == "--memory-f32") { params.memory_f16 = false; + } else if (arg == "--sampling-seq") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.samplers_sequence = argv[i]; } else if (arg == "--top-p") { if (++i >= argc) { invalid_param = true; diff --git a/common/sampling.cpp b/common/sampling.cpp index 1317024c2..891d05736 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -99,6 +99,42 @@ std::string llama_sampling_print(const llama_sampling_params & params) { return std::string(result); } +std::string llama_sampling_order_print(const llama_sampling_params & params) { + std::string result = "CFG -> Penalties "; + if (params.mirostat == 0){ + for (auto s : params.samplers_sequence){ + switch (s){ + case 'k':{ + result += "-> top_k "; + break; + } + case 'f':{ + result += "-> tfs_z "; + break; + } + case 'y':{ + result += "-> typical_p "; + break; + } + case 'p':{ + result += "-> top_p "; + break; + } + case 'm':{ + result += "-> min_p "; + break; + } + case 't':{ + result += "-> temp "; + break; + } + } + } + } else result += "-> mirostat "; + + return result; +} + llama_token llama_sampling_sample( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, @@ -108,20 +144,21 @@ llama_token llama_sampling_sample( const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); - const float temp = params.temp; - const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; - const float top_p = params.top_p; - const float min_p = params.min_p; - const float tfs_z = params.tfs_z; - const float typical_p = params.typical_p; - const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; - const float penalty_repeat = params.penalty_repeat; - const float penalty_freq = params.penalty_freq; - const float penalty_present = params.penalty_present; - const int mirostat = params.mirostat; - const float mirostat_tau = params.mirostat_tau; - const float mirostat_eta = params.mirostat_eta; - const bool penalize_nl = params.penalize_nl; + const float temp = params.temp; + const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; + const float top_p = params.top_p; + const float min_p = params.min_p; + const float tfs_z = params.tfs_z; + const float typical_p = params.typical_p; + const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; + const float penalty_repeat = params.penalty_repeat; + const float penalty_freq = params.penalty_freq; + const float penalty_present = params.penalty_present; + const int mirostat = params.mirostat; + const float mirostat_tau = params.mirostat_tau; + const float mirostat_eta = params.mirostat_eta; + const bool penalize_nl = params.penalize_nl; + const std::string samplers_sequence = params.samplers_sequence; auto & prev = ctx_sampling->prev; auto & cur = ctx_sampling->cur; @@ -188,12 +225,41 @@ llama_token llama_sampling_sample( // temperature sampling size_t min_keep = std::max(1, params.n_probs); - llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); - llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); - llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); - llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); - llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); - llama_sample_temp (ctx_main, &cur_p, temp); + // llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); + // llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); + // llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); + // llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); + // llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); + // llama_sample_temp (ctx_main, &cur_p, temp); + + for (auto s : samplers_sequence){ + switch (s){ + case 'k':{ + llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); + break; + } + case 'f':{ + llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); + break; + } + case 'y':{ + llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); + break; + } + case 'p':{ + llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); + break; + } + case 'm':{ + llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); + break; + } + case 't':{ + llama_sample_temp (ctx_main, &cur_p, temp); + break; + } + } + } id = llama_sample_token(ctx_main, &cur_p); diff --git a/common/sampling.h b/common/sampling.h index 7c9b8dcf2..c4191a6e5 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -10,22 +10,23 @@ // sampling parameters typedef struct llama_sampling_params { - int32_t n_prev = 64; // number of previous tokens to remember - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float min_p = 0.05f; // 0.0 = disabled - float tfs_z = 1.00f; // 1.0 = disabled - float typical_p = 1.00f; // 1.0 = disabled - float temp = 0.80f; // 1.0 = disabled - int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat = 1.10f; // 1.0 = disabled - float penalty_freq = 0.00f; // 0.0 = disabled - float penalty_present = 0.00f; // 0.0 = disabled - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate - bool penalize_nl = true; // consider newlines as a repeatable token + int32_t n_prev = 64; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled + float tfs_z = 1.00f; // 1.0 = disabled + float typical_p = 1.00f; // 1.0 = disabled + float temp = 0.80f; // 1.0 = disabled + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat = 1.10f; // 1.0 = disabled + float penalty_freq = 0.00f; // 0.0 = disabled + float penalty_present = 0.00f; // 0.0 = disabled + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + bool penalize_nl = true; // consider newlines as a repeatable token + std::string samplers_sequence = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp std::string grammar; // optional BNF-like grammar to constrain sampling @@ -80,6 +81,9 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama // Print sampling parameters into a string std::string llama_sampling_print(const llama_sampling_params & params); +// Print sampling order into a string +std::string llama_sampling_order_print(const llama_sampling_params & params); + // this is a common sampling function used across the examples for convenience // it can serve as a starting point for implementing your own sampling function // Note: When using multiple sequences, it is the caller's responsibility to call diff --git a/examples/main/main.cpp b/examples/main/main.cpp index c5cdfbf21..c096f110b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -437,6 +437,7 @@ int main(int argc, char ** argv) { } } LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str()); + LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str()); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("\n\n");