From a6c327884532822144a204430d2e721d047577e5 Mon Sep 17 00:00:00 2001 From: MaggotHATE Date: Tue, 5 Dec 2023 14:14:39 +0500 Subject: [PATCH] Formatting fixes --- common/common.cpp | 10 +++++----- common/sampling.cpp | 22 +++++++++++----------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 75588e156..b184fea09 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -904,11 +904,11 @@ std::string gpt_random_prompt(std::mt19937 & rng) { // String parsing // -std::string parse_samplers_input(std::string input){ +std::string parse_samplers_input(std::string input) { std::string output = ""; // since samplers names are written multiple ways // make it ready for both system names and input names - std::unordered_map samplers_symbols{ + std::unordered_map samplers_symbols { {"top_k", 'k'}, {"top-k", 'k'}, {"top_p", 'p'}, @@ -927,16 +927,16 @@ std::string parse_samplers_input(std::string input){ }; // expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p" size_t separator = input.find(';'); - while (separator != input.npos){ + while (separator != input.npos) { std::string name = input.substr(0,separator); input = input.substr(separator+1); separator = input.find(';'); - if (samplers_symbols.find(name) != samplers_symbols.end()){ + if (samplers_symbols.find(name) != samplers_symbols.end()) { output += samplers_symbols[name]; } } - if (samplers_symbols.find(input) != samplers_symbols.end()){ + if (samplers_symbols.find(input) != samplers_symbols.end()) { output += samplers_symbols[input]; } return output; diff --git a/common/sampling.cpp b/common/sampling.cpp index 57ead6607..7761ee94a 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -101,9 +101,9 @@ std::string llama_sampling_print(const llama_sampling_params & params) { std::string llama_sampling_order_print(const llama_sampling_params & params) { std::string result = "CFG -> Penalties "; - if (params.mirostat == 0){ - for (auto s : params.samplers_sequence){ - switch (s){ + if (params.mirostat == 0) { + for (auto s : params.samplers_sequence) { + switch (s) { case 'k': result += "-> top_k "; break; case 'f': result += "-> tfs_z "; break; case 'y': result += "-> typical_p "; break; @@ -126,15 +126,15 @@ void sampler_queue( size_t & min_keep) { const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); - const float temp = params.temp; - const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; - const float top_p = params.top_p; - const float min_p = params.min_p; - const float tfs_z = params.tfs_z; - const float typical_p = params.typical_p; - const std::string samplers_sequence = params.samplers_sequence; + const float temp = params.temp; + const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; + const float top_p = params.top_p; + const float min_p = params.min_p; + const float tfs_z = params.tfs_z; + const float typical_p = params.typical_p; + const std::string & samplers_sequence = params.samplers_sequence; - for (auto s : samplers_sequence){ + for (auto s : samplers_sequence) { switch (s){ case 'k': llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;