This commit is contained in:
wbpxre150 2024-06-12 21:46:49 +00:00 committed by GitHub
commit 430a2b178a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -14,6 +14,7 @@
#include <sstream>
#include <string>
#include <vector>
#include <sstream>
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
@ -31,6 +32,8 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
llama_context * ctx;
static llama_context ** g_ctx;
static llama_model ** g_model;
static gpt_params * g_params;
@ -38,6 +41,7 @@ static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false;
static bool is_command = false;
static bool file_exists(const std::string &path) {
std::ifstream f(path.c_str());
@ -111,6 +115,54 @@ static void sigint_handler(int signo) {
}
#endif
int command(std::string buffer, gpt_params &params, const int n_ctx ) {
// check buffer's first 3 chars equal '???' to enter command mode.
if (buffer.length() <= 3 || strncmp(buffer.c_str(), "???", 3) != 0) return 0;
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
std::istringstream command(buffer);
int j = 0; std::string test, arg, cmd;
while (command>>test) {
j++;
if ( j == 2 ) arg = test;
if ( j == 3 ) cmd = test;
}
if (cmd == "") {
printf("Please enter a command value.\n");
return 1;
}
if (arg == "n_predict") {
params.n_predict = std::stoi(cmd);
} else if (arg == "top_k") {
params.top_k = std::stoi(cmd);
} else if (arg == "ctx_size") {
params.n_ctx = std::stoi(cmd);
} else if (arg == "top_p") {
params.top_p = std::stof(cmd);
} else if (arg == "temp") {
params.temp = std::stof(cmd);
} else if (arg == "repeat_last_n") {
params.repeat_last_n = std::stoi(cmd);
} else if (arg == "repeat_penalty") {
params.repeat_penalty = std::stof(cmd);
} else if (arg == "batch_size") {
params.n_batch = std::stoi(cmd);
params.n_batch = std::min(512, params.n_batch);
} else if (arg == "reverse-prompt") {
params.antiprompt.push_back(cmd);
} else if (arg == "keep") {
params.n_keep = std::stoi(cmd);
} else if (arg == "stats") {
llama_print_timings(ctx);
} else {
printf("Invalid command: %s\nValid options are:\n n_predict, top_k, ctx_size, top_p, temp, repeat_last_n, repeat_penalty, batch_size, reverse-prompt, keep, stats\n", arg.c_str());
return 1;
}
printf("sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
printf("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
return 1;
}
static void llama_log_callback_logTee(ggml_log_level level, const char * text, void * user_data) {
(void) level;
(void) user_data;
@ -186,7 +238,6 @@ int main(int argc, char ** argv) {
LOG("%s: llama backend init\n", __func__);
llama_backend_init();
llama_numa_init(params.numa);
llama_model * model;
llama_context * ctx;
llama_context * ctx_guidance = NULL;
@ -711,6 +762,8 @@ int main(int argc, char ** argv) {
// display text
if (input_echo && display) {
// if a command was entered clear the output to stop printing of gibberish.
if (is_command == true) embd.clear();
for (auto id : embd) {
const std::string token_str = llama_token_to_piece(ctx, id, params.special);
@ -834,6 +887,14 @@ int main(int argc, char ** argv) {
// Add tokens to embd only if the input buffer is non-empty
// Entering a empty line lets the user pass control back
if (buffer.length() > 1) {
if (command(buffer, params, n_ctx) == 0) {
// this is not a command, run normally.
is_command = false;
} else {
// this was a command, so we need to stop anything more from printing.
is_command = true;
}
// append input suffix if any
if (!params.input_suffix.empty() && !params.conversation) {
LOG("appending input suffix: '%s'\n", params.input_suffix.c_str());