Rewrote with unordered_map
This commit is contained in:
parent
d363df3444
commit
bd08c8fab3
2 changed files with 121 additions and 57 deletions
|
@ -1,5 +1,7 @@
|
||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
|
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
|
||||||
struct llama_sampling_context * result = new llama_sampling_context();
|
struct llama_sampling_context * result = new llama_sampling_context();
|
||||||
|
|
||||||
|
@ -101,40 +103,96 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
|
||||||
|
|
||||||
std::string llama_sampling_order_print(const llama_sampling_params & params) {
|
std::string llama_sampling_order_print(const llama_sampling_params & params) {
|
||||||
std::string result = "CFG -> Penalties ";
|
std::string result = "CFG -> Penalties ";
|
||||||
|
|
||||||
|
std::unordered_map<char, std::string> samplers_map_display {
|
||||||
|
{'k', "-> top_k "},
|
||||||
|
{'f', "-> tfs_z "},
|
||||||
|
{'y', "-> typical_p "},
|
||||||
|
{'p', "-> top_p "},
|
||||||
|
{'m', "-> min_p "},
|
||||||
|
{'t', "-> temp "}
|
||||||
|
};
|
||||||
|
|
||||||
if (params.mirostat == 0){
|
if (params.mirostat == 0){
|
||||||
for (auto s : params.samplers_sequence){
|
for (auto s : params.samplers_sequence){
|
||||||
switch (s){
|
result += samplers_map_display[s];
|
||||||
case 'k':{
|
|
||||||
result += "-> top_k ";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case 'f':{
|
|
||||||
result += "-> tfs_z ";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case 'y':{
|
|
||||||
result += "-> typical_p ";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case 'p':{
|
|
||||||
result += "-> top_p ";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case 'm':{
|
|
||||||
result += "-> min_p ";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case 't':{
|
|
||||||
result += "-> temp ";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else result += "-> mirostat ";
|
} else result += "-> mirostat ";
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void sample_top_k(
|
||||||
|
const llama_sampling_params & params,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
llama_token_data_array & cur_p,
|
||||||
|
size_t & min_keep){
|
||||||
|
|
||||||
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
|
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
||||||
|
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
|
||||||
|
}
|
||||||
|
|
||||||
|
void sample_top_p(
|
||||||
|
const llama_sampling_params & params,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
llama_token_data_array & cur_p,
|
||||||
|
size_t & min_keep){
|
||||||
|
|
||||||
|
const float top_p = params.top_p;
|
||||||
|
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
|
||||||
|
}
|
||||||
|
|
||||||
|
void sample_tfs_z(
|
||||||
|
const llama_sampling_params & params,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
llama_token_data_array & cur_p,
|
||||||
|
size_t & min_keep){
|
||||||
|
|
||||||
|
const float tfs_z = params.tfs_z;
|
||||||
|
llama_sample_tail_free (ctx_main, &cur_p, tfs_z, min_keep);
|
||||||
|
}
|
||||||
|
|
||||||
|
void sample_typical_p(
|
||||||
|
const llama_sampling_params & params,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
llama_token_data_array & cur_p,
|
||||||
|
size_t & min_keep){
|
||||||
|
|
||||||
|
const float typical_p = params.typical_p;
|
||||||
|
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
|
||||||
|
}
|
||||||
|
|
||||||
|
void sample_min_p(
|
||||||
|
const llama_sampling_params & params,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
llama_token_data_array & cur_p,
|
||||||
|
size_t & min_keep){
|
||||||
|
|
||||||
|
const float min_p = params.min_p;
|
||||||
|
llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep);
|
||||||
|
}
|
||||||
|
|
||||||
|
void sample_temp(
|
||||||
|
const llama_sampling_params & params,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
llama_token_data_array & cur_p,
|
||||||
|
size_t & min_keep){
|
||||||
|
|
||||||
|
const float temp = params.temp;
|
||||||
|
llama_sample_temp (ctx_main, &cur_p, temp);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unordered_map<char, std::function<void(const llama_sampling_params &, struct llama_context *, llama_token_data_array&, size_t&)>> samplers_map
|
||||||
|
{
|
||||||
|
{'k', sample_top_k},
|
||||||
|
{'f', sample_tfs_z},
|
||||||
|
{'y', sample_typical_p},
|
||||||
|
{'p', sample_top_p},
|
||||||
|
{'m', sample_min_p},
|
||||||
|
{'t', sample_temp}
|
||||||
|
};
|
||||||
|
|
||||||
llama_token llama_sampling_sample(
|
llama_token llama_sampling_sample(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
|
@ -145,11 +203,6 @@ llama_token llama_sampling_sample(
|
||||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
|
|
||||||
const float temp = params.temp;
|
const float temp = params.temp;
|
||||||
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
|
||||||
const float top_p = params.top_p;
|
|
||||||
const float min_p = params.min_p;
|
|
||||||
const float tfs_z = params.tfs_z;
|
|
||||||
const float typical_p = params.typical_p;
|
|
||||||
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
|
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
|
||||||
const float penalty_repeat = params.penalty_repeat;
|
const float penalty_repeat = params.penalty_repeat;
|
||||||
const float penalty_freq = params.penalty_freq;
|
const float penalty_freq = params.penalty_freq;
|
||||||
|
@ -226,32 +279,7 @@ llama_token llama_sampling_sample(
|
||||||
size_t min_keep = std::max(1, params.n_probs);
|
size_t min_keep = std::max(1, params.n_probs);
|
||||||
|
|
||||||
for (auto s : samplers_sequence){
|
for (auto s : samplers_sequence){
|
||||||
switch (s){
|
samplers_map[s](params, ctx_main, cur_p, min_keep);
|
||||||
case 'k':{
|
|
||||||
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case 'f':{
|
|
||||||
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case 'y':{
|
|
||||||
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case 'p':{
|
|
||||||
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case 'm':{
|
|
||||||
llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case 't':{
|
|
||||||
llama_sample_temp (ctx_main, &cur_p, temp);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
id = llama_sample_token(ctx_main, &cur_p);
|
id = llama_sample_token(ctx_main, &cur_p);
|
||||||
|
|
|
@ -84,6 +84,42 @@ std::string llama_sampling_print(const llama_sampling_params & params);
|
||||||
// Print sampling order into a string
|
// Print sampling order into a string
|
||||||
std::string llama_sampling_order_print(const llama_sampling_params & params);
|
std::string llama_sampling_order_print(const llama_sampling_params & params);
|
||||||
|
|
||||||
|
void sample_top_k(
|
||||||
|
const llama_sampling_params & params,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
llama_token_data_array & cur_p,
|
||||||
|
size_t & min_keep);
|
||||||
|
|
||||||
|
void sample_top_p(
|
||||||
|
const llama_sampling_params & params,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
llama_token_data_array & cur_p,
|
||||||
|
size_t & min_keep);
|
||||||
|
|
||||||
|
void sample_tfs_z(
|
||||||
|
const llama_sampling_params & params,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
llama_token_data_array & cur_p,
|
||||||
|
size_t & min_keep);
|
||||||
|
|
||||||
|
void sample_typical_p(
|
||||||
|
const llama_sampling_params & params,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
llama_token_data_array & cur_p,
|
||||||
|
size_t & min_keep);
|
||||||
|
|
||||||
|
void sample_min_p(
|
||||||
|
const llama_sampling_params & params,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
llama_token_data_array & cur_p,
|
||||||
|
size_t & min_keep);
|
||||||
|
|
||||||
|
void sample_temp(
|
||||||
|
const llama_sampling_params & params,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
llama_token_data_array & cur_p,
|
||||||
|
size_t & min_keep);
|
||||||
|
|
||||||
// this is a common sampling function used across the examples for convenience
|
// this is a common sampling function used across the examples for convenience
|
||||||
// it can serve as a starting point for implementing your own sampling function
|
// it can serve as a starting point for implementing your own sampling function
|
||||||
// Note: When using multiple sequences, it is the caller's responsibility to call
|
// Note: When using multiple sequences, it is the caller's responsibility to call
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue