diff --git a/llama.cpp b/llama.cpp
index bafb4d866..77bfc3e76 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -736,13 +736,16 @@ int llama_main(
     gpt_vocab vocab,
     llama_model model,
     int64_t t_load_us,
-    int64_t t_main_start_us) {
+    int64_t t_main_start_us,
+    std::istream & instream,
+    FILE *outstream,
+    FILE *errstream) {
 
     if (params.seed < 0) {
         params.seed = time(NULL);
     }
 
-    fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
+    fprintf(errstream, "%s: seed = %d\n", __func__, params.seed);
 
     std::mt19937 rng(params.seed);
     if (params.random_prompt) {
@@ -788,13 +791,13 @@ int llama_main(
         params.interactive = true;
     }
 
-    fprintf(stderr, "\n");
-    fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-    fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+    fprintf(errstream, "\n");
+    fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+    fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
     for (int i = 0; i < (int) embd_inp.size(); i++) {
-        fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
+        fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
     }
-    fprintf(stderr, "\n");
+    fprintf(errstream, "\n");
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
@@ -806,22 +809,22 @@ int llama_main(
         signal(SIGINT, sigint_handler);
 #endif
 
-        fprintf(stderr, "%s: interactive mode on.\n", __func__);
+        fprintf(errstream, "%s: interactive mode on.\n", __func__);
 
         if(antipromptv_inp.size()) {
             for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) {
                 auto antiprompt_inp = antipromptv_inp.at(apindex);
-                fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
-                fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
+                fprintf(errstream, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
+                fprintf(errstream, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
                 for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
-                    fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
+                    fprintf(errstream, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
                 }
-                fprintf(stderr, "\n");
+                fprintf(errstream, "\n");
             }
         }
     }
-    fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
-    fprintf(stderr, "\n\n");
+    fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+    fprintf(errstream, "\n\n");
 
     std::vector<gpt_vocab::id> embd;
 
@@ -834,7 +837,7 @@ int llama_main(
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
     if (params.interactive) {
-        fprintf(stderr, "== Running in interactive mode. ==\n"
+        fprintf(errstream, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
                " - Press Ctrl+C to interject at any time.\n"
 #endif
@@ -850,7 +853,7 @@ int llama_main(
 
     // set the color for the prompt which will be output initially
     if (params.use_color) {
-        printf(ANSI_COLOR_YELLOW);
+        fprintf(outstream, ANSI_COLOR_YELLOW);
     }
 
     while (remaining_tokens > 0 || params.interactive) {
@@ -859,7 +862,7 @@ int llama_main(
             const int64_t t_start_us = ggml_time_us();
 
             if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
-                fprintf(stderr, "Failed to predict\n");
+                fprintf(errstream, "Failed to predict\n");
                 return 1;
             }
 
@@ -920,9 +923,9 @@ int llama_main(
         // display text
         if (!input_noecho) {
             for (auto id : embd) {
-                printf("%s", vocab.id_to_token[id].c_str());
+                fprintf(outstream, "%s", vocab.id_to_token[id].c_str());
             }
-            fflush(stdout);
+            fflush(outstream);
         }
         // reset color to default if we there is no pending user input
         if (!input_noecho && params.use_color && (int)embd_inp.size() == input_consumed) {
@@ -954,7 +957,7 @@ int llama_main(
                 std::string line;
                 bool another_line = true;
                 do {
-                    std::getline(std::cin, line);
+                    std::getline(instream, line);
                     if (line.empty() || line.back() != '\\') {
                         another_line = false;
                     } else {
@@ -983,7 +986,7 @@ int llama_main(
             if (params.interactive) {
                 is_interacting = true;
             } else {
-                fprintf(stderr, " [end of text]\n");
+                fprintf(errstream, " [end of text]\n");
                 break;
             }
         }
@@ -1003,18 +1006,18 @@ int llama_main(
     {
         const int64_t t_main_end_us = ggml_time_us();
 
-        fprintf(stderr, "\n\n");
-        fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
-        fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
-        fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
-        fprintf(stderr, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
-        fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+        fprintf(errstream, "\n\n");
+        fprintf(errstream, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        fprintf(errstream, "%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        fprintf(errstream, "%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        fprintf(errstream, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        fprintf(errstream, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
     }
 
     ggml_free(model.ctx);
 
     if (params.use_color) {
-        printf(ANSI_COLOR_RESET);
+        fprintf(outstream, ANSI_COLOR_RESET);
     }
 
     return 0;
diff --git a/llama.h b/llama.h
index cf777c6b6..e8d1d43a5 100644
--- a/llama.h
+++ b/llama.h
@@ -64,5 +64,8 @@ int llama_main(
     gpt_vocab vocab,
     llama_model model,
     int64_t t_load_us,
-    int64_t t_main_start_us);
+    int64_t t_main_start_us,
+    std::istream & instream,
+    FILE *outstream,
+    FILE *errstream);
 bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type);
diff --git a/main.cpp b/main.cpp
index 5d3cddeaa..0079c3975 100644
--- a/main.cpp
+++ b/main.cpp
@@ -2,6 +2,8 @@
 #include "utils.h"
 #include "llama.h"
 
+#include <iostream>
+
 const char * llama_print_system_info(void) {
     static std::string s;
 
@@ -63,5 +65,5 @@ int main(int argc, char ** argv) {
                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
-    return llama_main(params, vocab, model, t_main_start_us, t_load_us);
+    return llama_main(params, vocab, model, t_main_start_us, t_load_us, std::cin, stdout, stderr);
 }
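
With this change, `llama_main` no longer hard-codes `std::cin`, `stdout`, and `stderr`, so an embedding caller can substitute its own streams. The sketch below is illustrative only: it assumes the first parameter is the `gpt_params` struct used by `main.cpp` (declared in `utils.h`), that model/vocab loading and the timing values are prepared exactly as `main.cpp` does (elided here), and the wrapper name `run_captured` and the output filename are made up for the example.

```cpp
// Illustrative sketch: drive llama_main() with a scripted prompt and a capture
// file instead of the interactive std::cin/stdout pair. Model, vocab and the
// timing values are assumed to be set up as in main.cpp.
#include "utils.h"   // gpt_params, gpt_vocab (assumed, as included by main.cpp)
#include "llama.h"   // llama_main with the new stream parameters

#include <cstdint>
#include <cstdio>
#include <sstream>

int run_captured(gpt_params params, gpt_vocab vocab, llama_model model,
                 int64_t t_load_us, int64_t t_main_start_us) {
    // Scripted "user input", consumed wherever llama_main would read a line.
    std::istringstream scripted("Tell me a story.\n");

    // Generated text goes to a file; diagnostics stay on the real stderr.
    FILE * out = std::fopen("generation.txt", "w");
    if (out == nullptr) {
        return 1;
    }

    const int ret = llama_main(params, vocab, model, t_load_us, t_main_start_us,
                               scripted, out, stderr);

    std::fclose(out);
    return ret;
}
```

One practical note: since `ANSI_COLOR_YELLOW` and `ANSI_COLOR_RESET` are now written to `outstream`, a caller that captures output to a file will likely want `params.use_color` left off to keep escape codes out of the capture.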