From 0d903962e567bf3a56093e9156d12404208e6ba3 Mon Sep 17 00:00:00 2001 From: wbpxre150 Date: Thu, 13 Apr 2023 04:58:35 +0800 Subject: [PATCH 1/8] add command line to interactive mode. You can specify diffrent values for everything in params during rumtime. --- examples/main/main.cpp | 50 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 718bb6029..fcf1af5b0 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) @@ -167,7 +168,7 @@ int main(int argc, char ** argv) { // in instruct mode, we inject a prefix and a suffix to each input by the user if (params.instruct) { params.interactive_start = true; - params.antiprompt.push_back("### Instruction:\n\n"); + params.antiprompt.push_back("###"); } // enable interactive mode if reverse prompt or interactive start is specified @@ -438,8 +439,53 @@ int main(int argc, char ** argv) { } n_remain -= line_inp.size(); - } + // check buffer's first 3 chars equal '???' to enter command mode. + if (strncmp(buffer.c_str(), "???", 3) == 0 && buffer.length() > 3) { + std::istringstream command(buffer); + int j = 0; std::string test, arg, cmd; + while (command>>test) { + j++; + if ( j == 2 ) { + arg = test; + } + if ( j == 3 ) { + cmd = test; + } + } + if (arg == "-n" || arg == "--n_predict") { + params.n_predict = std::stoi(cmd); + } else if (arg == "--top_k") { + params.top_k = std::stoi(cmd); + } else if (arg == "-c" || arg == "--ctx_size") { + params.n_ctx = std::stoi(cmd); + } else if (arg == "--top_p") { + params.top_p = std::stof(cmd); + } else if (arg == "--temp") { + params.temp = std::stof(cmd); + } else if (arg == "--repeat_last_n") { + params.repeat_last_n = std::stoi(cmd); + } else if (arg == "--repeat_penalty") { + params.repeat_penalty = std::stof(cmd); + } else if (arg == "-b" || arg == "--batch_size") { + params.n_batch = std::stoi(cmd); + params.n_batch = std::min(512, params.n_batch); + } else if (arg == "-r" || arg == "--reverse-prompt") { + params.antiprompt.push_back(cmd); + } else if (arg == "--keep") { + params.n_keep = std::stoi(cmd); + } else if (arg == "stats") { + llama_print_timings(ctx); + } else { + printf("invalid argument parsed: %s\n", arg.c_str()); + } + set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + printf("sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", + params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); + printf("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); + } + } + input_noecho = true; // do not echo this again } From f45dec3c2a3770888c84614c08b01da49d18c5dd Mon Sep 17 00:00:00 2001 From: wbpxre150 Date: Thu, 13 Apr 2023 05:32:12 +0800 Subject: [PATCH 2/8] fixes --- examples/main/main.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index fcf1af5b0..c1d6fbcd8 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -442,6 +442,7 @@ int main(int argc, char ** argv) { // check buffer's first 3 chars equal '???' to enter command mode. if (strncmp(buffer.c_str(), "???", 3) == 0 && buffer.length() > 3) { + set_console_color(con_st, CONSOLE_COLOR_DEFAULT); std::istringstream command(buffer); int j = 0; std::string test, arg, cmd; while (command>>test) { @@ -477,9 +478,8 @@ int main(int argc, char ** argv) { } else if (arg == "stats") { llama_print_timings(ctx); } else { - printf("invalid argument parsed: %s\n", arg.c_str()); + printf("invalid argument parsed: %s\n please use -n, --top_k, -c, --top_p, --temp, --repeat_last_n, --repeat_penalty, -b, -r, --keep or stats", arg.c_str()); } - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); printf("sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); printf("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); From a1b4e48ba21642f8759bb77ad713a73001363721 Mon Sep 17 00:00:00 2001 From: wbpxre150 Date: Thu, 13 Apr 2023 05:46:56 +0800 Subject: [PATCH 3/8] fixO --- examples/gpt4all.sh | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/gpt4all.sh b/examples/gpt4all.sh index d974f95a9..b3604f200 100755 --- a/examples/gpt4all.sh +++ b/examples/gpt4all.sh @@ -7,9 +7,8 @@ cd `dirname $0` cd .. -./main --color --instruct --threads 4 \ - --model ./models/gpt4all-7B/gpt4all-lora-quantized.bin \ - --file ./prompts/alpaca.txt \ +./main --color --instruct --threads 6 \ + --model ./models/gpt4all.bin \ --batch_size 8 --ctx_size 2048 \ --repeat_last_n 64 --repeat_penalty 1.3 \ - --n_predict 128 --temp 0.1 --top_k 40 --top_p 0.95 + --n_predict 2048 --temp 0.1 --top_k 40 --top_p 0.95 --mlock From fa651909bbad29ae0d5a1f1ad57e4b8026e46807 Mon Sep 17 00:00:00 2001 From: wbpxre150 Date: Fri, 14 Apr 2023 01:24:32 +0800 Subject: [PATCH 4/8] refector code into function. --- examples/gpt4all.sh | 2 +- examples/main/main.cpp | 94 ++++++++++++++++++++++-------------------- 2 files changed, 50 insertions(+), 46 deletions(-) diff --git a/examples/gpt4all.sh b/examples/gpt4all.sh index b3604f200..4b1c15f0d 100755 --- a/examples/gpt4all.sh +++ b/examples/gpt4all.sh @@ -11,4 +11,4 @@ cd .. --model ./models/gpt4all.bin \ --batch_size 8 --ctx_size 2048 \ --repeat_last_n 64 --repeat_penalty 1.3 \ - --n_predict 2048 --temp 0.1 --top_k 40 --top_p 0.95 --mlock + --n_predict 2048 --temp 0.1 --top_k 50 --top_p 0.99 --mlock diff --git a/examples/main/main.cpp b/examples/main/main.cpp index c1d6fbcd8..b70b034a7 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -47,6 +47,53 @@ void sigint_handler(int signo) { } #endif +void command(std::string buffer, gpt_params params, const int n_ctx ) { + // check buffer's first 3 chars equal '???' to enter command mode. + if (strncmp(buffer.c_str(), "???", 3) == 0 && buffer.length() > 3) { + set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + std::istringstream command(buffer); + int j = 0; std::string test, arg, cmd; + while (command>>test) { + j++; + if ( j == 2 ) { + arg = test; + } + if ( j == 3 ) { + cmd = test; + } + } + if (arg == "-n" || arg == "--n_predict") { + params.n_predict = std::stoi(cmd); + } else if (arg == "--top_k") { + params.top_k = std::stoi(cmd); + } else if (arg == "-c" || arg == "--ctx_size") { + params.n_ctx = std::stoi(cmd); + } else if (arg == "--top_p") { + params.top_p = std::stof(cmd); + } else if (arg == "--temp") { + params.temp = std::stof(cmd); + } else if (arg == "--repeat_last_n") { + params.repeat_last_n = std::stoi(cmd); + } else if (arg == "--repeat_penalty") { + params.repeat_penalty = std::stof(cmd); + } else if (arg == "-b" || arg == "--batch_size") { + params.n_batch = std::stoi(cmd); + params.n_batch = std::min(512, params.n_batch); + } else if (arg == "-r" || arg == "--reverse-prompt") { + params.antiprompt.push_back(cmd); + } else if (arg == "--keep") { + params.n_keep = std::stoi(cmd); + } else if (arg == "stats") { + llama_print_timings(ctx); + } else { + printf("invalid argument parsed: %s\n please use -n, --top_k, -c, --top_p, --temp, --repeat_last_n, --repeat_penalty, -b, -r, --keep or stats", arg.c_str()); + } + printf("sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", + params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); + printf("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); + } +} + int main(int argc, char ** argv) { gpt_params params; params.model = "models/llama-7B/ggml-model.bin"; @@ -439,51 +486,8 @@ int main(int argc, char ** argv) { } n_remain -= line_inp.size(); - - // check buffer's first 3 chars equal '???' to enter command mode. - if (strncmp(buffer.c_str(), "???", 3) == 0 && buffer.length() > 3) { - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - std::istringstream command(buffer); - int j = 0; std::string test, arg, cmd; - while (command>>test) { - j++; - if ( j == 2 ) { - arg = test; - } - if ( j == 3 ) { - cmd = test; - } - } - if (arg == "-n" || arg == "--n_predict") { - params.n_predict = std::stoi(cmd); - } else if (arg == "--top_k") { - params.top_k = std::stoi(cmd); - } else if (arg == "-c" || arg == "--ctx_size") { - params.n_ctx = std::stoi(cmd); - } else if (arg == "--top_p") { - params.top_p = std::stof(cmd); - } else if (arg == "--temp") { - params.temp = std::stof(cmd); - } else if (arg == "--repeat_last_n") { - params.repeat_last_n = std::stoi(cmd); - } else if (arg == "--repeat_penalty") { - params.repeat_penalty = std::stof(cmd); - } else if (arg == "-b" || arg == "--batch_size") { - params.n_batch = std::stoi(cmd); - params.n_batch = std::min(512, params.n_batch); - } else if (arg == "-r" || arg == "--reverse-prompt") { - params.antiprompt.push_back(cmd); - } else if (arg == "--keep") { - params.n_keep = std::stoi(cmd); - } else if (arg == "stats") { - llama_print_timings(ctx); - } else { - printf("invalid argument parsed: %s\n please use -n, --top_k, -c, --top_p, --temp, --repeat_last_n, --repeat_penalty, -b, -r, --keep or stats", arg.c_str()); - } - printf("sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", - params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); - printf("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); - } + + command(buffer, params, n_ctx); } input_noecho = true; // do not echo this again From 315e69fd7f26502e4e635de7186ea6499559323f Mon Sep 17 00:00:00 2001 From: wbpxre150 Date: Fri, 14 Apr 2023 01:58:38 +0800 Subject: [PATCH 5/8] fix code indentation. --- examples/main/main.cpp | 79 +++++++++++++++++++++--------------------- 1 file changed, 39 insertions(+), 40 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b70b034a7..c19ab0c0b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -49,49 +49,48 @@ void sigint_handler(int signo) { void command(std::string buffer, gpt_params params, const int n_ctx ) { // check buffer's first 3 chars equal '???' to enter command mode. - if (strncmp(buffer.c_str(), "???", 3) == 0 && buffer.length() > 3) { - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - std::istringstream command(buffer); - int j = 0; std::string test, arg, cmd; - while (command>>test) { - j++; - if ( j == 2 ) { - arg = test; - } - if ( j == 3 ) { - cmd = test; - } + if (buffer.length() <= 3 || strncmp(buffer.c_str(), "???", 3) != 0) return; + set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + std::istringstream command(buffer); + int j = 0; std::string test, arg, cmd; + while (command>>test) { + j++; + if ( j == 2 ) { + arg = test; } - if (arg == "-n" || arg == "--n_predict") { - params.n_predict = std::stoi(cmd); - } else if (arg == "--top_k") { - params.top_k = std::stoi(cmd); - } else if (arg == "-c" || arg == "--ctx_size") { - params.n_ctx = std::stoi(cmd); - } else if (arg == "--top_p") { - params.top_p = std::stof(cmd); - } else if (arg == "--temp") { - params.temp = std::stof(cmd); - } else if (arg == "--repeat_last_n") { - params.repeat_last_n = std::stoi(cmd); - } else if (arg == "--repeat_penalty") { - params.repeat_penalty = std::stof(cmd); - } else if (arg == "-b" || arg == "--batch_size") { - params.n_batch = std::stoi(cmd); - params.n_batch = std::min(512, params.n_batch); - } else if (arg == "-r" || arg == "--reverse-prompt") { - params.antiprompt.push_back(cmd); - } else if (arg == "--keep") { - params.n_keep = std::stoi(cmd); - } else if (arg == "stats") { - llama_print_timings(ctx); - } else { - printf("invalid argument parsed: %s\n please use -n, --top_k, -c, --top_p, --temp, --repeat_last_n, --repeat_penalty, -b, -r, --keep or stats", arg.c_str()); + if ( j == 3 ) { + cmd = test; } - printf("sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", - params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); - printf("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); } + if (arg == "-n" || arg == "--n_predict") { + params.n_predict = std::stoi(cmd); + } else if (arg == "--top_k") { + params.top_k = std::stoi(cmd); + } else if (arg == "-c" || arg == "--ctx_size") { + params.n_ctx = std::stoi(cmd); + } else if (arg == "--top_p") { + params.top_p = std::stof(cmd); + } else if (arg == "--temp") { + params.temp = std::stof(cmd); + } else if (arg == "--repeat_last_n") { + params.repeat_last_n = std::stoi(cmd); + } else if (arg == "--repeat_penalty") { + params.repeat_penalty = std::stof(cmd); + } else if (arg == "-b" || arg == "--batch_size") { + params.n_batch = std::stoi(cmd); + params.n_batch = std::min(512, params.n_batch); + } else if (arg == "-r" || arg == "--reverse-prompt") { + params.antiprompt.push_back(cmd); + } else if (arg == "--keep") { + params.n_keep = std::stoi(cmd); + } else if (arg == "stats") { + llama_print_timings(ctx); + } else { + printf("invalid argument parsed: %s\n please use -n, --top_k, -c, --top_p, --temp, --repeat_last_n, --repeat_penalty, -b, -r, --keep or stats", arg.c_str()); + } + printf("sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", + params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); + printf("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); } int main(int argc, char ** argv) { From aa6bca453fcc7c9800d2778009cce7ba0efa2f43 Mon Sep 17 00:00:00 2001 From: wbpxre150 Date: Fri, 14 Apr 2023 12:36:20 +0800 Subject: [PATCH 6/8] Fix prints. --- examples/main/main.cpp | 80 ++++++++++++++++++++++++------------------ 1 file changed, 45 insertions(+), 35 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index c19ab0c0b..5496a851d 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -27,6 +27,7 @@ static console_state con_st; static bool is_interacting = false; +static bool is_command = false; llama_context * ctx; @@ -47,50 +48,52 @@ void sigint_handler(int signo) { } #endif -void command(std::string buffer, gpt_params params, const int n_ctx ) { +int command(std::string buffer, gpt_params params, const int n_ctx ) { // check buffer's first 3 chars equal '???' to enter command mode. - if (buffer.length() <= 3 || strncmp(buffer.c_str(), "???", 3) != 0) return; + if (buffer.length() <= 3 || strncmp(buffer.c_str(), "???", 3) != 0) return 0; set_console_color(con_st, CONSOLE_COLOR_DEFAULT); std::istringstream command(buffer); int j = 0; std::string test, arg, cmd; while (command>>test) { j++; - if ( j == 2 ) { - arg = test; - } - if ( j == 3 ) { - cmd = test; - } + if ( j == 2 ) arg = test; + if ( j == 3 ) cmd = test; } - if (arg == "-n" || arg == "--n_predict") { + if (cmd == "") { + printf("Please enter a command value.\n"); + return 1; + } + if (arg == "n_predict") { params.n_predict = std::stoi(cmd); - } else if (arg == "--top_k") { + } else if (arg == "top_k") { params.top_k = std::stoi(cmd); - } else if (arg == "-c" || arg == "--ctx_size") { + } else if (arg == "ctx_size") { params.n_ctx = std::stoi(cmd); - } else if (arg == "--top_p") { + } else if (arg == "top_p") { params.top_p = std::stof(cmd); - } else if (arg == "--temp") { + } else if (arg == "temp") { params.temp = std::stof(cmd); - } else if (arg == "--repeat_last_n") { + } else if (arg == "repeat_last_n") { params.repeat_last_n = std::stoi(cmd); - } else if (arg == "--repeat_penalty") { + } else if (arg == "repeat_penalty") { params.repeat_penalty = std::stof(cmd); - } else if (arg == "-b" || arg == "--batch_size") { + } else if (arg == "batch_size") { params.n_batch = std::stoi(cmd); params.n_batch = std::min(512, params.n_batch); - } else if (arg == "-r" || arg == "--reverse-prompt") { + } else if (arg == "reverse-prompt") { params.antiprompt.push_back(cmd); - } else if (arg == "--keep") { + } else if (arg == "keep") { params.n_keep = std::stoi(cmd); } else if (arg == "stats") { llama_print_timings(ctx); } else { - printf("invalid argument parsed: %s\n please use -n, --top_k, -c, --top_p, --temp, --repeat_last_n, --repeat_penalty, -b, -r, --keep or stats", arg.c_str()); + printf("Invalid command: %s\nValid options are:\n n_predict, top_k, ctx_size, top_p, temp, repeat_last_n, repeat_penalty, batch_size, reverse-prompt, keep, stats\n", arg.c_str()); + return 1; } printf("sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); printf("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); + return 1; } int main(int argc, char ** argv) { @@ -386,6 +389,8 @@ int main(int argc, char ** argv) { // display text if (!input_noecho) { + // if a command was entered clear the output to stop printing of gibberish. + if (is_command == true) embd.clear(); for (auto id : embd) { printf("%s", llama_token_to_str(ctx, id)); } @@ -469,24 +474,29 @@ int main(int argc, char ** argv) { // Add tokens to embd only if the input buffer is non-empty // Entering a empty line lets the user pass control back if (buffer.length() > 1) { + //check for commands + if (command(buffer, params, n_ctx) == 0) { + // this is not a command, run normally. + is_command = false; + // instruct mode: insert instruction prefix + if (params.instruct && !is_antiprompt) { + n_consumed = embd_inp.size(); + embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); + } - // instruct mode: insert instruction prefix - if (params.instruct && !is_antiprompt) { - n_consumed = embd_inp.size(); - embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); + auto line_inp = ::llama_tokenize(ctx, buffer, false); + embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); + + // instruct mode: insert response suffix + if (params.instruct) { + embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); + } + + n_remain -= line_inp.size(); + } else { + // this was a command, so we need to stop anything more from printing. + is_command = true; } - - auto line_inp = ::llama_tokenize(ctx, buffer, false); - embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); - - // instruct mode: insert response suffix - if (params.instruct) { - embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); - } - - n_remain -= line_inp.size(); - - command(buffer, params, n_ctx); } input_noecho = true; // do not echo this again From e6468f95c19018db297d74bcb09dbc927727e483 Mon Sep 17 00:00:00 2001 From: wbpxre150 Date: Fri, 14 Apr 2023 12:38:56 +0800 Subject: [PATCH 7/8] whitespace --- examples/main/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 5496a851d..c22c7f436 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -498,7 +498,7 @@ int main(int argc, char ** argv) { is_command = true; } } - + input_noecho = true; // do not echo this again } From 10e73b08beab25b3c36159e2d22b082f809ab225 Mon Sep 17 00:00:00 2001 From: wbpxre150 Date: Fri, 14 Apr 2023 12:41:33 +0800 Subject: [PATCH 8/8] fix conflict --- examples/gpt4all.sh | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/gpt4all.sh b/examples/gpt4all.sh index 4b1c15f0d..ff2b02118 100755 --- a/examples/gpt4all.sh +++ b/examples/gpt4all.sh @@ -7,8 +7,9 @@ cd `dirname $0` cd .. -./main --color --instruct --threads 6 \ - --model ./models/gpt4all.bin \ - --batch_size 8 --ctx_size 2048 \ +./main --color --instruct --threads 4 \ + --model ./models/gpt4all-7B/gpt4all-lora-quantized.bin \ + --file ./prompts/alpaca.txt \ + --batch_size 8 --ctx_size 2048 -n -1 \ --repeat_last_n 64 --repeat_penalty 1.3 \ - --n_predict 2048 --temp 0.1 --top_k 50 --top_p 0.99 --mlock + --n_predict 128 --temp 0.1 --top_k 40 --top_p 0.95 \ No newline at end of file