add command line to interactive mode. You can specify diffrent values for everything in params during rumtime.

This commit is contained in:
wbpxre150 2023-04-13 04:58:35 +08:00
parent 47d809c692
commit 0d903962e5

View file

@ -14,6 +14,7 @@
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <sstream>
#include <vector> #include <vector>
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
@ -167,7 +168,7 @@ int main(int argc, char ** argv) {
// in instruct mode, we inject a prefix and a suffix to each input by the user // in instruct mode, we inject a prefix and a suffix to each input by the user
if (params.instruct) { if (params.instruct) {
params.interactive_start = true; params.interactive_start = true;
params.antiprompt.push_back("### Instruction:\n\n"); params.antiprompt.push_back("###");
} }
// enable interactive mode if reverse prompt or interactive start is specified // enable interactive mode if reverse prompt or interactive start is specified
@ -438,8 +439,53 @@ int main(int argc, char ** argv) {
} }
n_remain -= line_inp.size(); n_remain -= line_inp.size();
}
// check buffer's first 3 chars equal '???' to enter command mode.
if (strncmp(buffer.c_str(), "???", 3) == 0 && buffer.length() > 3) {
std::istringstream command(buffer);
int j = 0; std::string test, arg, cmd;
while (command>>test) {
j++;
if ( j == 2 ) {
arg = test;
}
if ( j == 3 ) {
cmd = test;
}
}
if (arg == "-n" || arg == "--n_predict") {
params.n_predict = std::stoi(cmd);
} else if (arg == "--top_k") {
params.top_k = std::stoi(cmd);
} else if (arg == "-c" || arg == "--ctx_size") {
params.n_ctx = std::stoi(cmd);
} else if (arg == "--top_p") {
params.top_p = std::stof(cmd);
} else if (arg == "--temp") {
params.temp = std::stof(cmd);
} else if (arg == "--repeat_last_n") {
params.repeat_last_n = std::stoi(cmd);
} else if (arg == "--repeat_penalty") {
params.repeat_penalty = std::stof(cmd);
} else if (arg == "-b" || arg == "--batch_size") {
params.n_batch = std::stoi(cmd);
params.n_batch = std::min(512, params.n_batch);
} else if (arg == "-r" || arg == "--reverse-prompt") {
params.antiprompt.push_back(cmd);
} else if (arg == "--keep") {
params.n_keep = std::stoi(cmd);
} else if (arg == "stats") {
llama_print_timings(ctx);
} else {
printf("invalid argument parsed: %s\n", arg.c_str());
}
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
printf("sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
printf("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
}
}
input_noecho = true; // do not echo this again input_noecho = true; // do not echo this again
} }