diff --git a/common/sampling.cpp b/common/sampling.cpp index 03c70fe95..e47dd1f79 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,7 +1,5 @@ #include "sampling.h" -#include - struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { struct llama_sampling_context * result = new llama_sampling_context(); @@ -103,96 +101,89 @@ std::string llama_sampling_print(const llama_sampling_params & params) { std::string llama_sampling_order_print(const llama_sampling_params & params) { std::string result = "CFG -> Penalties "; - - std::unordered_map samplers_map_display { - {'k', "-> top_k "}, - {'f', "-> tfs_z "}, - {'y', "-> typical_p "}, - {'p', "-> top_p "}, - {'m', "-> min_p "}, - {'t', "-> temp "} - }; - if (params.mirostat == 0){ for (auto s : params.samplers_sequence){ - result += samplers_map_display[s]; + switch (s){ + case 'k':{ + result += "-> top_k "; + break; + } + case 'f':{ + result += "-> tfs_z "; + break; + } + case 'y':{ + result += "-> typical_p "; + break; + } + case 'p':{ + result += "-> top_p "; + break; + } + case 'm':{ + result += "-> min_p "; + break; + } + case 't':{ + result += "-> temp "; + break; + } + default: break; + } } } else result += "-> mirostat "; return result; } -void sample_top_k( - const llama_sampling_params & params, - struct llama_context * ctx_main, - llama_token_data_array & cur_p, - size_t & min_keep){ +// no reasons to expose this function in header +void sampler_queue( + struct llama_context * ctx_main, + const llama_sampling_params & params, + llama_token_data_array & cur_p, + size_t & min_keep) { + const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + + const float temp = params.temp; + const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; + const float top_p = params.top_p; + const float min_p = params.min_p; + const float tfs_z = params.tfs_z; + const float typical_p = params.typical_p; + const std::string samplers_sequence = params.samplers_sequence; + + for (auto s : samplers_sequence){ + switch (s){ + case 'k':{ + llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); + break; + } + case 'f':{ + llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); + break; + } + case 'y':{ + llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); + break; + } + case 'p':{ + llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); + break; + } + case 'm':{ + llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); + break; + } + case 't':{ + llama_sample_temp (ctx_main, &cur_p, temp); + break; + } + default: break; + } + } - const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); - const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; - llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); } -void sample_top_p( - const llama_sampling_params & params, - struct llama_context * ctx_main, - llama_token_data_array & cur_p, - size_t & min_keep){ - - const float top_p = params.top_p; - llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); -} - -void sample_tfs_z( - const llama_sampling_params & params, - struct llama_context * ctx_main, - llama_token_data_array & cur_p, - size_t & min_keep){ - - const float tfs_z = params.tfs_z; - llama_sample_tail_free (ctx_main, &cur_p, tfs_z, min_keep); -} - -void sample_typical_p( - const llama_sampling_params & params, - struct llama_context * ctx_main, - llama_token_data_array & cur_p, - size_t & min_keep){ - - const float typical_p = params.typical_p; - llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); -} - -void sample_min_p( - const llama_sampling_params & params, - struct llama_context * ctx_main, - llama_token_data_array & cur_p, - size_t & min_keep){ - - const float min_p = params.min_p; - llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); -} - -void sample_temp( - const llama_sampling_params & params, - struct llama_context * ctx_main, - llama_token_data_array & cur_p, - size_t & min_keep){ - - const float temp = params.temp; - llama_sample_temp (ctx_main, &cur_p, temp); -} - -std::unordered_map> samplers_map -{ - {'k', sample_top_k}, - {'f', sample_tfs_z}, - {'y', sample_typical_p}, - {'p', sample_top_p}, - {'m', sample_min_p}, - {'t', sample_temp} -}; - llama_token llama_sampling_sample( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, @@ -211,7 +202,6 @@ llama_token llama_sampling_sample( const float mirostat_tau = params.mirostat_tau; const float mirostat_eta = params.mirostat_eta; const bool penalize_nl = params.penalize_nl; - const std::string samplers_sequence = params.samplers_sequence; auto & prev = ctx_sampling->prev; auto & cur = ctx_sampling->cur; @@ -278,9 +268,7 @@ llama_token llama_sampling_sample( // temperature sampling size_t min_keep = std::max(1, params.n_probs); - for (auto s : samplers_sequence){ - samplers_map[s](params, ctx_main, cur_p, min_keep); - } + sampler_queue(ctx_main, params, cur_p, min_keep); id = llama_sample_token(ctx_main, &cur_p); diff --git a/common/sampling.h b/common/sampling.h index 0ce10b232..9872b8b38 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -84,42 +84,6 @@ std::string llama_sampling_print(const llama_sampling_params & params); // Print sampling order into a string std::string llama_sampling_order_print(const llama_sampling_params & params); -void sample_top_k( - const llama_sampling_params & params, - struct llama_context * ctx_main, - llama_token_data_array & cur_p, - size_t & min_keep); - -void sample_top_p( - const llama_sampling_params & params, - struct llama_context * ctx_main, - llama_token_data_array & cur_p, - size_t & min_keep); - -void sample_tfs_z( - const llama_sampling_params & params, - struct llama_context * ctx_main, - llama_token_data_array & cur_p, - size_t & min_keep); - -void sample_typical_p( - const llama_sampling_params & params, - struct llama_context * ctx_main, - llama_token_data_array & cur_p, - size_t & min_keep); - -void sample_min_p( - const llama_sampling_params & params, - struct llama_context * ctx_main, - llama_token_data_array & cur_p, - size_t & min_keep); - -void sample_temp( - const llama_sampling_params & params, - struct llama_context * ctx_main, - llama_token_data_array & cur_p, - size_t & min_keep); - // this is a common sampling function used across the examples for convenience // it can serve as a starting point for implementing your own sampling function // Note: When using multiple sequences, it is the caller's responsibility to call