diff --git a/examples/simple-inference/simple-inference.cpp b/examples/simple-inference/simple-inference.cpp
index 835d05b61..c8c56ae34 100644
--- a/examples/simple-inference/simple-inference.cpp
+++ b/examples/simple-inference/simple-inference.cpp
@@ -10,6 +10,7 @@
 #include "build-info.h"
 #include "grammar-parser.h"
 
+#include <atomic>
 #include <cassert>
 #include <cinttypes>
 #include <cmath>
@@ -38,12 +39,7 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-static llama_context ** g_ctx;
-static llama_model ** g_model;
-static gpt_params * g_params;
-static std::vector<llama_token> * g_input_tokens;
-static std::ostringstream * g_output_ss;
-static std::vector<llama_token> * g_output_tokens;
+static std::atomic<bool> interrupted {false};
 
 void write_logfile(
     const llama_context * ctx, const gpt_params & params, const llama_model * model,
@@ -91,11 +87,7 @@ void write_logfile(
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
 void sigint_handler(int signo) {
     if (signo == SIGINT) {
-        console::cleanup();
-        printf("\n");
-        llama_print_timings(*g_ctx);
-        write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
-        _exit(130);
+        interrupted.store(true);
     }
 }
 #endif
@@ -173,9 +165,6 @@ bool initialize(llama_context **ctx_p, llama_model **model_p, gpt_params & param
     LOG("%s: llama backend init\n", __func__);
     llama_backend_init(params.numa);
 
-    g_model = model_p;
-    g_ctx = ctx_p;
-
     // load the model and apply lora adapter, if any
     LOG("%s: load the model and apply lora adapter, if any\n", __func__);
     std::tie(*model_p, *ctx_p) = llama_init_from_gpt_params(params);
@@ -275,7 +264,7 @@ bool initialize(llama_context **ctx_p, llama_model **model_p, gpt_params & param
 bool feed_prompt(llama_context *ctx, const gpt_params * params, llama_token * tokens, int tokens_len, int n_past) {
     console::set_display(console::prompt);
 
-    while (tokens_len > 0) {
+    while (tokens_len > 0 && interrupted.load() == false) {
         const int this_chunk_size = std::min(tokens_len, params->n_batch);
 
         if (llama_eval(ctx, tokens, this_chunk_size, n_past, params->n_threads)) {
@@ -302,7 +291,6 @@
 int main(int argc, char ** argv) {
     gpt_params params;
-    g_params = &params;
 
     if (gpt_params_parse(argc, argv, params) == false) {
         return 1;
     }
@@ -329,10 +317,7 @@
 
     const int n_ctx = llama_n_ctx(ctx);
     int n_remain = params.n_predict;
-
-    std::vector<llama_token> input_tokens; g_input_tokens = &input_tokens;
-    std::vector<llama_token> output_tokens; g_output_tokens = &output_tokens;
-    std::ostringstream output_ss; g_output_ss = &output_ss;
+    std::vector<llama_token> input_tokens;
 
     {
         LOG("warming up the model with an empty run\n");
@@ -356,23 +341,10 @@
     std::vector<llama_token_data> candidates;
     candidates.reserve(llama_n_vocab(ctx));
 
-    // Required to match output from main example with a specific seed - but why?
-    if (false) {
-        llama_token id = llama_sample_token(ctx, NULL, grammar, params, last_tokens, candidates);
-        if (llama_eval(ctx, &id, 1, last_tokens.size(), params.n_threads)) {
-            LOG_TEE("%s : failed to eval\n", __func__);
-            return 1;
-        }
-        const std::string token_str = llama_token_to_piece(ctx, id);
-        fputs(token_str.c_str(), stdout);
-        fflush(stdout);
-    }
-
-    while (n_remain > 0) {
+    while (n_remain > 0 && interrupted.load() == false) {
         const llama_token id = llama_sample_token(ctx, NULL, grammar, params, last_tokens, candidates);
 
         last_tokens.push_back(id);
-        output_tokens.push_back(id);
 
         --n_remain;
         LOG("n_remain: %d\n", n_remain);
@@ -384,8 +356,6 @@
         }
 
         const std::string token_str = llama_token_to_piece(ctx, id);
-
-        output_ss << token_str;
         fputs(token_str.c_str(), stdout);
         fflush(stdout);
 
@@ -396,6 +366,22 @@
         }
     }
 
+    std::vector<llama_token> output_tokens;
+    std::ostringstream output_ss;
+    const size_t prompt_size = prompt_tokens.size();
+
+    for (size_t i = 0; i < last_tokens.size(); i++) {
+        const std::string token_str = llama_token_to_piece(ctx, last_tokens[i]);
+        if (i >= prompt_size) {
+            output_ss << token_str;
+            output_tokens.push_back(last_tokens[i]);
+        }
+
+    }
+
+    console::cleanup();
+    printf("\n");
+    llama_print_timings(ctx);
 
     write_logfile(ctx, params, model, prompt_tokens, output_ss.str(), output_tokens);
 
@@ -411,5 +397,5 @@
     LOG_TEE("Log end\n")
 #endif // LOG_DISABLE_LOGS
 
-    return 0;
+    return interrupted.load() ? 130 : 0;
 }
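For reference, a minimal standalone sketch (not part of the patch above) of the interrupt pattern the diff adopts: the SIGINT handler only flips a `std::atomic<bool>` flag, the work loop polls the flag, and cleanup plus the conventional 128+SIGINT exit status happen on the normal return path instead of inside the handler. The work loop here is a placeholder, not llama.cpp API; the flag and handler names merely mirror the patch, and `std::signal` is used for brevity.

```cpp
#include <atomic>
#include <csignal>
#include <cstdio>

// Flag set by the signal handler and polled by the work loop.
// std::atomic<bool> is lock-free on mainstream platforms; storing to it
// is the only thing the handler does.
static std::atomic<bool> interrupted {false};

static void sigint_handler(int signo) {
    if (signo == SIGINT) {
        interrupted.store(true);
    }
}

int main() {
    std::signal(SIGINT, sigint_handler);

    // Placeholder work loop: check the flag between units of work, the same
    // way the patched prompt-feeding and generation loops check it between
    // batches and tokens.
    for (long i = 0; i < 100000000 && !interrupted.load(); i++) {
        // ... one unit of work ...
    }

    // Cleanup, timings and log writing now run on the normal exit path
    // rather than inside the handler (no more _exit(130) from the handler).
    printf("\n");

    return interrupted.load() ? 130 : 0; // 130 = 128 + SIGINT, shell convention
}
```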