Remove direct access to std streams from llama_main

The goal is to allow running llama_main while connected to other
streams, such as TCP sockets.

Signed-off-by: Thiago Padilha <thiago@padilha.cc>
Thiago Padilha 2023-03-18 12:20:20 -03:00
parent 8b9a9dc49f
commit 9ed33b37de
3 changed files with 38 additions and 30 deletions
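
For context, here is a minimal sketch of how a caller could use the new parameters to serve a TCP client instead of the terminal. It is not part of this commit: the helper serve_client and its wiring are hypothetical, it assumes a POSIX environment (fdopen, dup) plus the GNU libstdc++ extension __gnu_cxx::stdio_filebuf to wrap the socket descriptor as a std::istream, and it assumes the remaining arguments are prepared the same way main.cpp already does before calling llama_main.

// Hypothetical helper: drive llama_main over an accepted TCP connection.
// Assumes POSIX fdopen/dup and GNU libstdc++'s <ext/stdio_filebuf.h>.
#include <ext/stdio_filebuf.h>
#include <istream>
#include <cstdio>
#include <unistd.h>

#include "llama.h"   // llama_main declaration (as in this tree)
#include "utils.h"   // gpt_params, gpt_vocab (assumed location)

int serve_client(int client_fd, gpt_params params, gpt_vocab vocab,
                 llama_model model, int64_t t_main_start_us, int64_t t_load_us) {
    // Prompt input now comes from the socket instead of std::cin.
    __gnu_cxx::stdio_filebuf<char> inbuf(client_fd, std::ios::in);
    std::istream instream(&inbuf);

    // Generated text goes back over the same socket instead of stdout.
    FILE * outstream = fdopen(dup(client_fd), "w");
    if (outstream == NULL) {
        return 1;
    }

    // Diagnostics keep going to the server's own stderr.
    int ret = llama_main(params, vocab, model, t_main_start_us, t_load_us,
                         instream, outstream, stderr);

    fclose(outstream);
    return ret;
}

A plain accept() loop could call serve_client once per connection, while the usual stdin/stdout path remains available through the unchanged call at the bottom of main.cpp.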


@@ -736,13 +736,16 @@ int llama_main(
     gpt_vocab vocab,
     llama_model model,
     int64_t t_load_us,
-    int64_t t_main_start_us) {
+    int64_t t_main_start_us,
+    std::istream & instream,
+    FILE *outstream,
+    FILE *errstream) {
     if (params.seed < 0) {
         params.seed = time(NULL);
     }
-    fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
+    fprintf(errstream, "%s: seed = %d\n", __func__, params.seed);
     std::mt19937 rng(params.seed);
     if (params.random_prompt) {
@@ -788,13 +791,13 @@ int llama_main(
         params.interactive = true;
     }
-    fprintf(stderr, "\n");
-    fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-    fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+    fprintf(errstream, "\n");
+    fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+    fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
     for (int i = 0; i < (int) embd_inp.size(); i++) {
-        fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
+        fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
     }
-    fprintf(stderr, "\n");
+    fprintf(errstream, "\n");
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
@@ -806,22 +809,22 @@ int llama_main(
         signal(SIGINT, sigint_handler);
 #endif
-        fprintf(stderr, "%s: interactive mode on.\n", __func__);
+        fprintf(errstream, "%s: interactive mode on.\n", __func__);
         if(antipromptv_inp.size()) {
             for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) {
                 auto antiprompt_inp = antipromptv_inp.at(apindex);
-                fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
-                fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
+                fprintf(errstream, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
+                fprintf(errstream, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
                 for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
-                    fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
+                    fprintf(errstream, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
                 }
-                fprintf(stderr, "\n");
+                fprintf(errstream, "\n");
             }
         }
     }
-    fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
-    fprintf(stderr, "\n\n");
+    fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+    fprintf(errstream, "\n\n");
     std::vector<gpt_vocab::id> embd;
@@ -834,7 +837,7 @@ int llama_main(
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
     if (params.interactive) {
-        fprintf(stderr, "== Running in interactive mode. ==\n"
+        fprintf(errstream, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
                " - Press Ctrl+C to interject at any time.\n"
 #endif
@@ -850,7 +853,7 @@ int llama_main(
     // set the color for the prompt which will be output initially
     if (params.use_color) {
-        printf(ANSI_COLOR_YELLOW);
+        fprintf(outstream, ANSI_COLOR_YELLOW);
     }
     while (remaining_tokens > 0 || params.interactive) {
@@ -859,7 +862,7 @@ int llama_main(
             const int64_t t_start_us = ggml_time_us();
             if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
-                fprintf(stderr, "Failed to predict\n");
+                fprintf(errstream, "Failed to predict\n");
                 return 1;
             }
@@ -920,9 +923,9 @@ int llama_main(
         // display text
        if (!input_noecho) {
            for (auto id : embd) {
-                printf("%s", vocab.id_to_token[id].c_str());
+                fprintf(outstream, "%s", vocab.id_to_token[id].c_str());
            }
-            fflush(stdout);
+            fflush(outstream);
        }
         // reset color to default if we there is no pending user input
         if (!input_noecho && params.use_color && (int)embd_inp.size() == input_consumed) {
@@ -954,7 +957,7 @@ int llama_main(
                 std::string line;
                 bool another_line = true;
                 do {
-                    std::getline(std::cin, line);
+                    std::getline(instream, line);
                     if (line.empty() || line.back() != '\\') {
                         another_line = false;
                     } else {
@@ -983,7 +986,7 @@ int llama_main(
             if (params.interactive) {
                 is_interacting = true;
             } else {
-                fprintf(stderr, " [end of text]\n");
+                fprintf(errstream, " [end of text]\n");
                 break;
             }
         }
@@ -1003,18 +1006,18 @@ int llama_main(
     {
         const int64_t t_main_end_us = ggml_time_us();
-        fprintf(stderr, "\n\n");
-        fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
-        fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
-        fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
-        fprintf(stderr, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
-        fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+        fprintf(errstream, "\n\n");
+        fprintf(errstream, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        fprintf(errstream, "%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        fprintf(errstream, "%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        fprintf(errstream, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        fprintf(errstream, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
     }
     ggml_free(model.ctx);
     if (params.use_color) {
-        printf(ANSI_COLOR_RESET);
+        fprintf(outstream, ANSI_COLOR_RESET);
     }
     return 0;


@@ -64,5 +64,8 @@ int llama_main(
     gpt_vocab vocab,
     llama_model model,
     int64_t t_load_us,
-    int64_t t_main_start_us);
+    int64_t t_main_start_us,
+    std::istream & instream,
+    FILE *outstream,
+    FILE *errstream);
 bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type);


@@ -2,6 +2,8 @@
 #include "utils.h"
 #include "llama.h"
+#include <iostream>
 const char * llama_print_system_info(void) {
     static std::string s;
@@ -63,5 +65,5 @@ int main(int argc, char ** argv) {
             params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
-    return llama_main(params, vocab, model, t_main_start_us, t_load_us);
+    return llama_main(params, vocab, model, t_main_start_us, t_load_us, std::cin, stdout, stderr);
 }