From fa651909bbad29ae0d5a1f1ad57e4b8026e46807 Mon Sep 17 00:00:00 2001
From: wbpxre150
Date: Fri, 14 Apr 2023 01:24:32 +0800
Subject: [PATCH] refactor code into function.

---
 examples/gpt4all.sh    |   2 +-
 examples/main/main.cpp | 103 +++++++++++++++++++++++---------------------
 2 files changed, 59 insertions(+), 46 deletions(-)

diff --git a/examples/gpt4all.sh b/examples/gpt4all.sh
index b3604f200..4b1c15f0d 100755
--- a/examples/gpt4all.sh
+++ b/examples/gpt4all.sh
@@ -11,4 +11,4 @@ cd ..
     --model ./models/gpt4all.bin \
     --batch_size 8 --ctx_size 2048 \
     --repeat_last_n 64 --repeat_penalty 1.3 \
-    --n_predict 2048 --temp 0.1 --top_k 40 --top_p 0.95 --mlock
+    --n_predict 2048 --temp 0.1 --top_k 50 --top_p 0.99 --mlock
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index c1d6fbcd8..b70b034a7 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -47,6 +47,62 @@ void sigint_handler(int signo) {
 }
 #endif
 
+// Handle an interactive "??? <flag> <value>" command entered at the prompt:
+// updates sampling/generation settings on the fly, or prints timings for
+// "??? stats".  Input that does not start with "???" is ignored.
+//
+//   buffer - raw input line from the user
+//   params - updated in place; passed by reference so the changes persist
+//            (passing by value would silently discard every user setting)
+//   ctx    - llama context, required by llama_print_timings for "stats"
+//   n_ctx  - current context size, echoed back in the summary printout
+void command(const std::string & buffer, gpt_params & params, llama_context * ctx, const int n_ctx) {
+    // check buffer's first 3 chars equal '???' to enter command mode.
+    if (strncmp(buffer.c_str(), "???", 3) == 0 && buffer.length() > 3) {
+        set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
+        std::istringstream command(buffer);
+        int j = 0; std::string test, arg, cmd;
+        while (command>>test) {
+            j++;
+            if ( j == 2 ) {
+                arg = test;
+            }
+            if ( j == 3 ) {
+                cmd = test;
+            }
+        }
+        if (arg == "-n" || arg == "--n_predict") {
+            params.n_predict = std::stoi(cmd);
+        } else if (arg == "--top_k") {
+            params.top_k = std::stoi(cmd);
+        } else if (arg == "-c" || arg == "--ctx_size") {
+            params.n_ctx = std::stoi(cmd);
+        } else if (arg == "--top_p") {
+            params.top_p = std::stof(cmd);
+        } else if (arg == "--temp") {
+            params.temp = std::stof(cmd);
+        } else if (arg == "--repeat_last_n") {
+            params.repeat_last_n = std::stoi(cmd);
+        } else if (arg == "--repeat_penalty") {
+            params.repeat_penalty = std::stof(cmd);
+        } else if (arg == "-b" || arg == "--batch_size") {
+            params.n_batch = std::stoi(cmd);
+            params.n_batch = std::min(512, params.n_batch);
+        } else if (arg == "-r" || arg == "--reverse-prompt") {
+            params.antiprompt.push_back(cmd);
+        } else if (arg == "--keep") {
+            params.n_keep = std::stoi(cmd);
+        } else if (arg == "stats") {
+            llama_print_timings(ctx);
+        } else {
+            printf("invalid argument parsed: %s\n please use -n, --top_k, -c, --top_p, --temp, --repeat_last_n, --repeat_penalty, -b, -r, --keep or stats", arg.c_str());
+        }
+        printf("sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
+            params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+        printf("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
+    }
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
     params.model = "models/llama-7B/ggml-model.bin";
@@ -439,51 +495,8 @@ int main(int argc, char ** argv) {
             }
 
             n_remain -= line_inp.size();
-
-            // check buffer's first 3 chars equal '???' to enter command mode.
-            if (strncmp(buffer.c_str(), "???", 3) == 0 && buffer.length() > 3) {
-                set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
-                std::istringstream command(buffer);
-                int j = 0; std::string test, arg, cmd;
-                while (command>>test) {
-                    j++;
-                    if ( j == 2 ) {
-                        arg = test;
-                    }
-                    if ( j == 3 ) {
-                        cmd = test;
-                    }
-                }
-                if (arg == "-n" || arg == "--n_predict") {
-                    params.n_predict = std::stoi(cmd);
-                } else if (arg == "--top_k") {
-                    params.top_k = std::stoi(cmd);
-                } else if (arg == "-c" || arg == "--ctx_size") {
-                    params.n_ctx = std::stoi(cmd);
-                } else if (arg == "--top_p") {
-                    params.top_p = std::stof(cmd);
-                } else if (arg == "--temp") {
-                    params.temp = std::stof(cmd);
-                } else if (arg == "--repeat_last_n") {
-                    params.repeat_last_n = std::stoi(cmd);
-                } else if (arg == "--repeat_penalty") {
-                    params.repeat_penalty = std::stof(cmd);
-                } else if (arg == "-b" || arg == "--batch_size") {
-                    params.n_batch = std::stoi(cmd);
-                    params.n_batch = std::min(512, params.n_batch);
-                } else if (arg == "-r" || arg == "--reverse-prompt") {
-                    params.antiprompt.push_back(cmd);
-                } else if (arg == "--keep") {
-                    params.n_keep = std::stoi(cmd);
-                } else if (arg == "stats") {
-                    llama_print_timings(ctx);
-                } else {
-                    printf("invalid argument parsed: %s\n please use -n, --top_k, -c, --top_p, --temp, --repeat_last_n, --repeat_penalty, -b, -r, --keep or stats", arg.c_str());
-                }
-                printf("sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
-                    params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
-                printf("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
-            }
+
+            command(buffer, params, ctx, n_ctx);
         }
 
         input_noecho = true; // do not echo this again