Samplers sequence order with parameter
This commit is contained in:
parent
8d6d9f033b
commit
d4dc3d26fc
4 changed files with 113 additions and 36 deletions
|
@ -280,6 +280,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
|||
params.yarn_beta_slow = std::stof(argv[i]);
|
||||
} else if (arg == "--memory-f32") {
|
||||
params.memory_f16 = false;
|
||||
} else if (arg == "--sampling-seq") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
sparams.samplers_sequence = argv[i];
|
||||
} else if (arg == "--top-p") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
|
|
|
@ -99,6 +99,42 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
|
|||
return std::string(result);
|
||||
}
|
||||
|
||||
std::string llama_sampling_order_print(const llama_sampling_params & params) {
|
||||
std::string result = "CFG -> Penalties ";
|
||||
if (params.mirostat == 0){
|
||||
for (auto s : params.samplers_sequence){
|
||||
switch (s){
|
||||
case 'k':{
|
||||
result += "-> top_k ";
|
||||
break;
|
||||
}
|
||||
case 'f':{
|
||||
result += "-> tfs_z ";
|
||||
break;
|
||||
}
|
||||
case 'y':{
|
||||
result += "-> typical_p ";
|
||||
break;
|
||||
}
|
||||
case 'p':{
|
||||
result += "-> top_p ";
|
||||
break;
|
||||
}
|
||||
case 'm':{
|
||||
result += "-> min_p ";
|
||||
break;
|
||||
}
|
||||
case 't':{
|
||||
result += "-> temp ";
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else result += "-> mirostat ";
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
llama_token llama_sampling_sample(
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
|
@ -108,20 +144,21 @@ llama_token llama_sampling_sample(
|
|||
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||
|
||||
const float temp = params.temp;
|
||||
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
||||
const float top_p = params.top_p;
|
||||
const float min_p = params.min_p;
|
||||
const float tfs_z = params.tfs_z;
|
||||
const float typical_p = params.typical_p;
|
||||
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
|
||||
const float penalty_repeat = params.penalty_repeat;
|
||||
const float penalty_freq = params.penalty_freq;
|
||||
const float penalty_present = params.penalty_present;
|
||||
const int mirostat = params.mirostat;
|
||||
const float mirostat_tau = params.mirostat_tau;
|
||||
const float mirostat_eta = params.mirostat_eta;
|
||||
const bool penalize_nl = params.penalize_nl;
|
||||
const float temp = params.temp;
|
||||
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
||||
const float top_p = params.top_p;
|
||||
const float min_p = params.min_p;
|
||||
const float tfs_z = params.tfs_z;
|
||||
const float typical_p = params.typical_p;
|
||||
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
|
||||
const float penalty_repeat = params.penalty_repeat;
|
||||
const float penalty_freq = params.penalty_freq;
|
||||
const float penalty_present = params.penalty_present;
|
||||
const int mirostat = params.mirostat;
|
||||
const float mirostat_tau = params.mirostat_tau;
|
||||
const float mirostat_eta = params.mirostat_eta;
|
||||
const bool penalize_nl = params.penalize_nl;
|
||||
const std::string samplers_sequence = params.samplers_sequence;
|
||||
|
||||
auto & prev = ctx_sampling->prev;
|
||||
auto & cur = ctx_sampling->cur;
|
||||
|
@ -188,12 +225,41 @@ llama_token llama_sampling_sample(
|
|||
// temperature sampling
|
||||
size_t min_keep = std::max(1, params.n_probs);
|
||||
|
||||
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
|
||||
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
|
||||
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
|
||||
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
|
||||
llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep);
|
||||
llama_sample_temp (ctx_main, &cur_p, temp);
|
||||
// llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
|
||||
// llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
|
||||
// llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
|
||||
// llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
|
||||
// llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep);
|
||||
// llama_sample_temp (ctx_main, &cur_p, temp);
|
||||
|
||||
for (auto s : samplers_sequence){
|
||||
switch (s){
|
||||
case 'k':{
|
||||
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
|
||||
break;
|
||||
}
|
||||
case 'f':{
|
||||
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
|
||||
break;
|
||||
}
|
||||
case 'y':{
|
||||
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
|
||||
break;
|
||||
}
|
||||
case 'p':{
|
||||
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
|
||||
break;
|
||||
}
|
||||
case 'm':{
|
||||
llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep);
|
||||
break;
|
||||
}
|
||||
case 't':{
|
||||
llama_sample_temp (ctx_main, &cur_p, temp);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
id = llama_sample_token(ctx_main, &cur_p);
|
||||
|
||||
|
|
|
@ -10,22 +10,23 @@
|
|||
|
||||
// sampling parameters
|
||||
typedef struct llama_sampling_params {
|
||||
int32_t n_prev = 64; // number of previous tokens to remember
|
||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||
int32_t top_k = 40; // <= 0 to use vocab size
|
||||
float top_p = 0.95f; // 1.0 = disabled
|
||||
float min_p = 0.05f; // 0.0 = disabled
|
||||
float tfs_z = 1.00f; // 1.0 = disabled
|
||||
float typical_p = 1.00f; // 1.0 = disabled
|
||||
float temp = 0.80f; // 1.0 = disabled
|
||||
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
float penalty_repeat = 1.10f; // 1.0 = disabled
|
||||
float penalty_freq = 0.00f; // 0.0 = disabled
|
||||
float penalty_present = 0.00f; // 0.0 = disabled
|
||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float mirostat_tau = 5.00f; // target entropy
|
||||
float mirostat_eta = 0.10f; // learning rate
|
||||
bool penalize_nl = true; // consider newlines as a repeatable token
|
||||
int32_t n_prev = 64; // number of previous tokens to remember
|
||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||
int32_t top_k = 40; // <= 0 to use vocab size
|
||||
float top_p = 0.95f; // 1.0 = disabled
|
||||
float min_p = 0.05f; // 0.0 = disabled
|
||||
float tfs_z = 1.00f; // 1.0 = disabled
|
||||
float typical_p = 1.00f; // 1.0 = disabled
|
||||
float temp = 0.80f; // 1.0 = disabled
|
||||
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
float penalty_repeat = 1.10f; // 1.0 = disabled
|
||||
float penalty_freq = 0.00f; // 0.0 = disabled
|
||||
float penalty_present = 0.00f; // 0.0 = disabled
|
||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float mirostat_tau = 5.00f; // target entropy
|
||||
float mirostat_eta = 0.10f; // learning rate
|
||||
bool penalize_nl = true; // consider newlines as a repeatable token
|
||||
std::string samplers_sequence = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp
|
||||
|
||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||
|
||||
|
@ -80,6 +81,9 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama
|
|||
// Print sampling parameters into a string
|
||||
std::string llama_sampling_print(const llama_sampling_params & params);
|
||||
|
||||
// Print the sampling order into a string, e.g. "CFG -> Penalties -> top_k -> ... -> temp " (or "-> mirostat " when mirostat is enabled)
|
||||
std::string llama_sampling_order_print(const llama_sampling_params & params);
|
||||
|
||||
// this is a common sampling function used across the examples for convenience
|
||||
// it can serve as a starting point for implementing your own sampling function
|
||||
// Note: When using multiple sequences, it is the caller's responsibility to call
|
||||
|
|
|
@ -437,6 +437,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
|
||||
LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
|
||||
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||
LOG_TEE("\n\n");
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue