Refactor: table-driven sampler dispatch in common/sampling

Replaces the two per-character switch statements (sampling-order printing and
the sampling loop) with lookup tables keyed by the sampler character.

Review notes on this revision of the patch:
  * Both tables are const and are queried with find(). operator[] on a
    std::unordered_map default-inserts a missing key; for samplers_map that
    yields an empty std::function, and calling it throws
    std::bad_function_call. The switch statements being removed have no
    default case, so a character outside "kfypmt" is skipped; find() keeps
    that behaviour.
  * <functional> is included for std::function, next to <unordered_map>.
  * sample_temp leaves its last parameter unnamed in the definition because
    llama_sample_temp is called without min_keep.
  * Whitespace inside individual lines was rebuilt from a copy of this patch
    whose whitespace had been collapsed. If context lines do not match
    byte-for-byte, apply with `git apply --ignore-whitespace`. The index
    hashes are carried over unchanged from that copy.

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 71f92aa65..03c70fe95 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -1,5 +1,8 @@
 #include "sampling.h"
 
+#include <functional>
+#include <unordered_map>
+
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
@@ -101,40 +104,109 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
 
 std::string llama_sampling_order_print(const llama_sampling_params & params) {
     std::string result = "CFG -> Penalties ";
+
+    // label appended for each sampler key found in params.samplers_sequence
+    static const std::unordered_map<char, std::string> samplers_map_display {
+        {'k', "-> top_k "},
+        {'f', "-> tfs_z "},
+        {'y', "-> typical_p "},
+        {'p', "-> top_p "},
+        {'m', "-> min_p "},
+        {'t', "-> temp "}
+    };
+
     if (params.mirostat == 0){
         for (auto s : params.samplers_sequence){
-            switch (s){
-                case 'k':{
-                    result += "-> top_k ";
-                    break;
-                }
-                case 'f':{
-                    result += "-> tfs_z ";
-                    break;
-                }
-                case 'y':{
-                    result += "-> typical_p ";
-                    break;
-                }
-                case 'p':{
-                    result += "-> top_p ";
-                    break;
-                }
-                case 'm':{
-                    result += "-> min_p ";
-                    break;
-                }
-                case 't':{
-                    result += "-> temp ";
-                    break;
-                }
-            }
+            // keys without a label are skipped, like the switch this replaces
+            const auto it = samplers_map_display.find(s);
+            if (it != samplers_map_display.end()) {
+                result += it->second;
+            }
         }
     } else result += "-> mirostat ";
 
     return result;
 }
 
+// Adapters with one shared signature so that they can be stored in samplers_map.
+// Each reads its own setting from params and forwards it to the matching llama_sample_* call.
+void sample_top_k(
+        const llama_sampling_params & params,
+        struct llama_context * ctx_main,
+        llama_token_data_array & cur_p,
+        size_t & min_keep){
+
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
+    llama_sample_top_k(ctx_main, &cur_p, top_k, min_keep);
+}
+
+void sample_top_p(
+        const llama_sampling_params & params,
+        struct llama_context * ctx_main,
+        llama_token_data_array & cur_p,
+        size_t & min_keep){
+
+    const float top_p = params.top_p;
+    llama_sample_top_p(ctx_main, &cur_p, top_p, min_keep);
+}
+
+void sample_tfs_z(
+        const llama_sampling_params & params,
+        struct llama_context * ctx_main,
+        llama_token_data_array & cur_p,
+        size_t & min_keep){
+
+    const float tfs_z = params.tfs_z;
+    llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
+}
+
+void sample_typical_p(
+        const llama_sampling_params & params,
+        struct llama_context * ctx_main,
+        llama_token_data_array & cur_p,
+        size_t & min_keep){
+
+    const float typical_p = params.typical_p;
+    llama_sample_typical(ctx_main, &cur_p, typical_p, min_keep);
+}
+
+void sample_min_p(
+        const llama_sampling_params & params,
+        struct llama_context * ctx_main,
+        llama_token_data_array & cur_p,
+        size_t & min_keep){
+
+    const float min_p = params.min_p;
+    llama_sample_min_p(ctx_main, &cur_p, min_p, min_keep);
+}
+
+void sample_temp(
+        const llama_sampling_params & params,
+        struct llama_context * ctx_main,
+        llama_token_data_array & cur_p,
+        size_t & /* min_keep */){
+
+    // llama_sample_temp is called without min_keep, so that parameter stays unnamed
+    const float temp = params.temp;
+    llama_sample_temp(ctx_main, &cur_p, temp);
+}
+
+// sampler key -> adapter; looked up once per character of the samplers sequence
+static const std::unordered_map<char, std::function<void(
+        const llama_sampling_params &,
+        struct llama_context *,
+        llama_token_data_array &,
+        size_t &)>> samplers_map
+{
+    {'k', sample_top_k},
+    {'f', sample_tfs_z},
+    {'y', sample_typical_p},
+    {'p', sample_top_p},
+    {'m', sample_min_p},
+    {'t', sample_temp}
+};
+
 llama_token llama_sampling_sample(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
@@ -145,11 +217,6 @@ llama_token llama_sampling_sample(
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
 
     const float temp = params.temp;
-    const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
-    const float top_p = params.top_p;
-    const float min_p = params.min_p;
-    const float tfs_z = params.tfs_z;
-    const float typical_p = params.typical_p;
     const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
     const float penalty_repeat = params.penalty_repeat;
     const float penalty_freq = params.penalty_freq;
@@ -226,32 +293,12 @@ llama_token llama_sampling_sample(
             size_t min_keep = std::max(1, params.n_probs);
 
             for (auto s : samplers_sequence){
-                switch (s){
-                    case 'k':{
-                        llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
-                        break;
-                    }
-                    case 'f':{
-                        llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
-                        break;
-                    }
-                    case 'y':{
-                        llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
-                        break;
-                    }
-                    case 'p':{
-                        llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
-                        break;
-                    }
-                    case 'm':{
-                        llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep);
-                        break;
-                    }
-                    case 't':{
-                        llama_sample_temp (ctx_main, &cur_p, temp);
-                        break;
-                    }
-                }
+                // keys with no registered sampler are skipped, like the switch this replaces;
+                // operator[] would default-insert an empty std::function and calling it throws
+                const auto it = samplers_map.find(s);
+                if (it != samplers_map.end()) {
+                    it->second(params, ctx_main, cur_p, min_keep);
+                }
             }
 
             id = llama_sample_token(ctx_main, &cur_p);
diff --git a/common/sampling.h b/common/sampling.h
index 9872b8b38..0ce10b232 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -84,6 +84,44 @@ std::string llama_sampling_print(const llama_sampling_params & params);
 // Print sampling order into a string
 std::string llama_sampling_order_print(const llama_sampling_params & params);
 
+// Per-sampler adapters that share one signature: each forwards to a single llama_sample_* call
+// on cur_p, taking its setting from params
+void sample_top_k(
+        const llama_sampling_params & params,
+        struct llama_context * ctx_main,
+        llama_token_data_array & cur_p,
+        size_t & min_keep);
+
+void sample_top_p(
+        const llama_sampling_params & params,
+        struct llama_context * ctx_main,
+        llama_token_data_array & cur_p,
+        size_t & min_keep);
+
+void sample_tfs_z(
+        const llama_sampling_params & params,
+        struct llama_context * ctx_main,
+        llama_token_data_array & cur_p,
+        size_t & min_keep);
+
+void sample_typical_p(
+        const llama_sampling_params & params,
+        struct llama_context * ctx_main,
+        llama_token_data_array & cur_p,
+        size_t & min_keep);
+
+void sample_min_p(
+        const llama_sampling_params & params,
+        struct llama_context * ctx_main,
+        llama_token_data_array & cur_p,
+        size_t & min_keep);
+
+void sample_temp(
+        const llama_sampling_params & params,
+        struct llama_context * ctx_main,
+        llama_token_data_array & cur_p,
+        size_t & min_keep);
+
 // this is a common sampling function used across the examples for convenience
 // it can serve as a starting point for implementing your own sampling function
 // Note: When using multiple sequences, it is the caller's responsibility to call