refector code into function.
This commit is contained in:
parent
a1b4e48ba2
commit
fa651909bb
2 changed files with 50 additions and 46 deletions
|
@ -11,4 +11,4 @@ cd ..
|
|||
--model ./models/gpt4all.bin \
|
||||
--batch_size 8 --ctx_size 2048 \
|
||||
--repeat_last_n 64 --repeat_penalty 1.3 \
|
||||
--n_predict 2048 --temp 0.1 --top_k 40 --top_p 0.95 --mlock
|
||||
--n_predict 2048 --temp 0.1 --top_k 50 --top_p 0.99 --mlock
|
||||
|
|
|
@ -47,6 +47,53 @@ void sigint_handler(int signo) {
|
|||
}
|
||||
#endif
|
||||
|
||||
void command(std::string buffer, gpt_params params, const int n_ctx ) {
|
||||
// check buffer's first 3 chars equal '???' to enter command mode.
|
||||
if (strncmp(buffer.c_str(), "???", 3) == 0 && buffer.length() > 3) {
|
||||
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
|
||||
std::istringstream command(buffer);
|
||||
int j = 0; std::string test, arg, cmd;
|
||||
while (command>>test) {
|
||||
j++;
|
||||
if ( j == 2 ) {
|
||||
arg = test;
|
||||
}
|
||||
if ( j == 3 ) {
|
||||
cmd = test;
|
||||
}
|
||||
}
|
||||
if (arg == "-n" || arg == "--n_predict") {
|
||||
params.n_predict = std::stoi(cmd);
|
||||
} else if (arg == "--top_k") {
|
||||
params.top_k = std::stoi(cmd);
|
||||
} else if (arg == "-c" || arg == "--ctx_size") {
|
||||
params.n_ctx = std::stoi(cmd);
|
||||
} else if (arg == "--top_p") {
|
||||
params.top_p = std::stof(cmd);
|
||||
} else if (arg == "--temp") {
|
||||
params.temp = std::stof(cmd);
|
||||
} else if (arg == "--repeat_last_n") {
|
||||
params.repeat_last_n = std::stoi(cmd);
|
||||
} else if (arg == "--repeat_penalty") {
|
||||
params.repeat_penalty = std::stof(cmd);
|
||||
} else if (arg == "-b" || arg == "--batch_size") {
|
||||
params.n_batch = std::stoi(cmd);
|
||||
params.n_batch = std::min(512, params.n_batch);
|
||||
} else if (arg == "-r" || arg == "--reverse-prompt") {
|
||||
params.antiprompt.push_back(cmd);
|
||||
} else if (arg == "--keep") {
|
||||
params.n_keep = std::stoi(cmd);
|
||||
} else if (arg == "stats") {
|
||||
llama_print_timings(ctx);
|
||||
} else {
|
||||
printf("invalid argument parsed: %s\n please use -n, --top_k, -c, --top_p, --temp, --repeat_last_n, --repeat_penalty, -b, -r, --keep or stats", arg.c_str());
|
||||
}
|
||||
printf("sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
|
||||
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
|
||||
printf("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
params.model = "models/llama-7B/ggml-model.bin";
|
||||
|
@ -440,50 +487,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
n_remain -= line_inp.size();
|
||||
|
||||
// check buffer's first 3 chars equal '???' to enter command mode.
|
||||
if (strncmp(buffer.c_str(), "???", 3) == 0 && buffer.length() > 3) {
|
||||
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
|
||||
std::istringstream command(buffer);
|
||||
int j = 0; std::string test, arg, cmd;
|
||||
while (command>>test) {
|
||||
j++;
|
||||
if ( j == 2 ) {
|
||||
arg = test;
|
||||
}
|
||||
if ( j == 3 ) {
|
||||
cmd = test;
|
||||
}
|
||||
}
|
||||
if (arg == "-n" || arg == "--n_predict") {
|
||||
params.n_predict = std::stoi(cmd);
|
||||
} else if (arg == "--top_k") {
|
||||
params.top_k = std::stoi(cmd);
|
||||
} else if (arg == "-c" || arg == "--ctx_size") {
|
||||
params.n_ctx = std::stoi(cmd);
|
||||
} else if (arg == "--top_p") {
|
||||
params.top_p = std::stof(cmd);
|
||||
} else if (arg == "--temp") {
|
||||
params.temp = std::stof(cmd);
|
||||
} else if (arg == "--repeat_last_n") {
|
||||
params.repeat_last_n = std::stoi(cmd);
|
||||
} else if (arg == "--repeat_penalty") {
|
||||
params.repeat_penalty = std::stof(cmd);
|
||||
} else if (arg == "-b" || arg == "--batch_size") {
|
||||
params.n_batch = std::stoi(cmd);
|
||||
params.n_batch = std::min(512, params.n_batch);
|
||||
} else if (arg == "-r" || arg == "--reverse-prompt") {
|
||||
params.antiprompt.push_back(cmd);
|
||||
} else if (arg == "--keep") {
|
||||
params.n_keep = std::stoi(cmd);
|
||||
} else if (arg == "stats") {
|
||||
llama_print_timings(ctx);
|
||||
} else {
|
||||
printf("invalid argument parsed: %s\n please use -n, --top_k, -c, --top_p, --temp, --repeat_last_n, --repeat_penalty, -b, -r, --keep or stats", arg.c_str());
|
||||
}
|
||||
printf("sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
|
||||
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
|
||||
printf("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||
}
|
||||
command(buffer, params, n_ctx);
|
||||
}
|
||||
|
||||
input_noecho = true; // do not echo this again
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue