Formatting fixes
This commit is contained in:
parent
c879b6d183
commit
a6c3278845
2 changed files with 16 additions and 16 deletions
|
@ -904,11 +904,11 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
|
||||||
// String parsing
|
// String parsing
|
||||||
//
|
//
|
||||||
|
|
||||||
std::string parse_samplers_input(std::string input){
|
std::string parse_samplers_input(std::string input) {
|
||||||
std::string output = "";
|
std::string output = "";
|
||||||
// since samplers names are written multiple ways
|
// since samplers names are written multiple ways
|
||||||
// make it ready for both system names and input names
|
// make it ready for both system names and input names
|
||||||
std::unordered_map<std::string, char> samplers_symbols{
|
std::unordered_map<std::string, char> samplers_symbols {
|
||||||
{"top_k", 'k'},
|
{"top_k", 'k'},
|
||||||
{"top-k", 'k'},
|
{"top-k", 'k'},
|
||||||
{"top_p", 'p'},
|
{"top_p", 'p'},
|
||||||
|
@ -927,16 +927,16 @@ std::string parse_samplers_input(std::string input){
|
||||||
};
|
};
|
||||||
// expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p"
|
// expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p"
|
||||||
size_t separator = input.find(';');
|
size_t separator = input.find(';');
|
||||||
while (separator != input.npos){
|
while (separator != input.npos) {
|
||||||
std::string name = input.substr(0,separator);
|
std::string name = input.substr(0,separator);
|
||||||
input = input.substr(separator+1);
|
input = input.substr(separator+1);
|
||||||
separator = input.find(';');
|
separator = input.find(';');
|
||||||
|
|
||||||
if (samplers_symbols.find(name) != samplers_symbols.end()){
|
if (samplers_symbols.find(name) != samplers_symbols.end()) {
|
||||||
output += samplers_symbols[name];
|
output += samplers_symbols[name];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (samplers_symbols.find(input) != samplers_symbols.end()){
|
if (samplers_symbols.find(input) != samplers_symbols.end()) {
|
||||||
output += samplers_symbols[input];
|
output += samplers_symbols[input];
|
||||||
}
|
}
|
||||||
return output;
|
return output;
|
||||||
|
|
|
@ -101,9 +101,9 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
|
||||||
|
|
||||||
std::string llama_sampling_order_print(const llama_sampling_params & params) {
|
std::string llama_sampling_order_print(const llama_sampling_params & params) {
|
||||||
std::string result = "CFG -> Penalties ";
|
std::string result = "CFG -> Penalties ";
|
||||||
if (params.mirostat == 0){
|
if (params.mirostat == 0) {
|
||||||
for (auto s : params.samplers_sequence){
|
for (auto s : params.samplers_sequence) {
|
||||||
switch (s){
|
switch (s) {
|
||||||
case 'k': result += "-> top_k "; break;
|
case 'k': result += "-> top_k "; break;
|
||||||
case 'f': result += "-> tfs_z "; break;
|
case 'f': result += "-> tfs_z "; break;
|
||||||
case 'y': result += "-> typical_p "; break;
|
case 'y': result += "-> typical_p "; break;
|
||||||
|
@ -126,15 +126,15 @@ void sampler_queue(
|
||||||
size_t & min_keep) {
|
size_t & min_keep) {
|
||||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
|
|
||||||
const float temp = params.temp;
|
const float temp = params.temp;
|
||||||
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
||||||
const float top_p = params.top_p;
|
const float top_p = params.top_p;
|
||||||
const float min_p = params.min_p;
|
const float min_p = params.min_p;
|
||||||
const float tfs_z = params.tfs_z;
|
const float tfs_z = params.tfs_z;
|
||||||
const float typical_p = params.typical_p;
|
const float typical_p = params.typical_p;
|
||||||
const std::string samplers_sequence = params.samplers_sequence;
|
const std::string & samplers_sequence = params.samplers_sequence;
|
||||||
|
|
||||||
for (auto s : samplers_sequence){
|
for (auto s : samplers_sequence) {
|
||||||
switch (s){
|
switch (s){
|
||||||
case 'k': llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
|
case 'k': llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
|
||||||
case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
|
case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue