Clean up interrupt handling, avoid globals

KerfuffleV2 2023-09-13 05:56:16 -06:00
parent 62c5c6f5c3
commit d75698c3b0

@@ -10,6 +10,7 @@
 #include "build-info.h"
 #include "grammar-parser.h"

+#include <atomic>
 #include <cassert>
 #include <cinttypes>
 #include <cmath>
@@ -38,12 +39,7 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif

-static llama_context ** g_ctx;
-static llama_model ** g_model;
-static gpt_params * g_params;
-static std::vector<llama_token> * g_input_tokens;
-static std::ostringstream * g_output_ss;
-static std::vector<llama_token> * g_output_tokens;
+static std::atomic<bool> interrupted {false};

 void write_logfile(
     const llama_context * ctx, const gpt_params & params, const llama_model * model,
@@ -91,11 +87,7 @@ void write_logfile(
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
 void sigint_handler(int signo) {
     if (signo == SIGINT) {
-        console::cleanup();
-        printf("\n");
-        llama_print_timings(*g_ctx);
-        write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
-        _exit(130);
+        interrupted.store(true);
     }
 }
 #endif
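
Context, not part of the diff itself: the handler is now limited to a single atomic store, which is async-signal-safe, while console cleanup, timing output and log writing move to the normal exit path in main. A minimal self-contained sketch of that pattern follows; the std::signal registration in it is only illustrative, since this diff does not show how the example actually installs the handler.

    #include <atomic>
    #include <csignal>
    #include <cstdio>

    static std::atomic<bool> interrupted {false};

    // Only an atomic store happens in the handler; printf/cleanup/logging are
    // not async-signal-safe, so they now run after the main loop instead.
    static void sigint_handler(int signo) {
        if (signo == SIGINT) {
            interrupted.store(true);
        }
    }

    int main() {
        std::signal(SIGINT, sigint_handler);  // illustrative registration only

        // The work loop polls the flag and winds down cleanly when it is set.
        for (long i = 0; i < 100000000 && !interrupted.load(); i++) {
            // ... one unit of work, e.g. sample and print one token ...
        }

        std::puts(interrupted.load() ? "interrupted" : "finished");
        return interrupted.load() ? 130 : 0;
    }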
@@ -173,9 +165,6 @@ bool initialize(llama_context **ctx_p, llama_model **model_p, gpt_params & param
     LOG("%s: llama backend init\n", __func__);
     llama_backend_init(params.numa);

-    g_model = model_p;
-    g_ctx = ctx_p;
-
     // load the model and apply lora adapter, if any
     LOG("%s: load the model and apply lora adapter, if any\n", __func__);
     std::tie(*model_p, *ctx_p) = llama_init_from_gpt_params(params);
@@ -275,7 +264,7 @@ bool initialize(llama_context **ctx_p, llama_model **model_p, gpt_params & param
 bool feed_prompt(llama_context *ctx, const gpt_params * params, llama_token * tokens, int tokens_len, int n_past) {
     console::set_display(console::prompt);

-    while (tokens_len > 0) {
+    while (tokens_len > 0 && interrupted.load() == false) {
         const int this_chunk_size = std::min(tokens_len, params->n_batch);

         if (llama_eval(ctx, tokens, this_chunk_size, n_past, params->n_threads)) {
@@ -302,7 +291,6 @@ bool feed_prompt(llama_context *ctx, const gpt_params * params, llama_token * to
 int main(int argc, char ** argv) {
     gpt_params params;
-    g_params = &params;

     if (gpt_params_parse(argc, argv, params) == false) {
         return 1;
@@ -329,10 +317,7 @@ int main(int argc, char ** argv) {
     const int n_ctx = llama_n_ctx(ctx);
     int n_remain = params.n_predict;

-    std::vector<int> input_tokens; g_input_tokens = &input_tokens;
-    std::vector<int> output_tokens; g_output_tokens = &output_tokens;
-    std::ostringstream output_ss; g_output_ss = &output_ss;
+    std::vector<int> input_tokens;

     {
         LOG("warming up the model with an empty run\n");
@@ -356,23 +341,10 @@ int main(int argc, char ** argv) {
     std::vector<llama_token_data> candidates;
     candidates.reserve(llama_n_vocab(ctx));

-    // Required to match output from main example with a specific seed - but why?
-    if (false) {
-        llama_token id = llama_sample_token(ctx, NULL, grammar, params, last_tokens, candidates);
-        if (llama_eval(ctx, &id, 1, last_tokens.size(), params.n_threads)) {
-            LOG_TEE("%s : failed to eval\n", __func__);
-            return 1;
-        }
-        const std::string token_str = llama_token_to_piece(ctx, id);
-        fputs(token_str.c_str(), stdout);
-        fflush(stdout);
-    }
-
-    while (n_remain > 0) {
+    while (n_remain > 0 && interrupted.load() == false) {
         const llama_token id = llama_sample_token(ctx, NULL, grammar, params, last_tokens, candidates);
         last_tokens.push_back(id);
-        output_tokens.push_back(id);

         --n_remain;
         LOG("n_remain: %d\n", n_remain);
@@ -384,8 +356,6 @@ int main(int argc, char ** argv) {
         }

         const std::string token_str = llama_token_to_piece(ctx, id);
-        output_ss << token_str;
-
         fputs(token_str.c_str(), stdout);
         fflush(stdout);
@@ -396,6 +366,22 @@ int main(int argc, char ** argv) {
         }
     }

+    std::vector<int> output_tokens;
+    std::ostringstream output_ss;
+    const size_t prompt_size = prompt_tokens.size();
+
+    for (size_t i = 0; i < last_tokens.size(); i++) {
+        const std::string token_str = llama_token_to_piece(ctx, last_tokens[i]);
+
+        if (i >= prompt_size) {
+            output_ss << token_str;
+            output_tokens.push_back(last_tokens[i]);
+        }
+    }
+
+    console::cleanup();
+    printf("\n");
+
     llama_print_timings(ctx);
     write_logfile(ctx, params, model, prompt_tokens, output_ss.str(), output_tokens);
@@ -411,5 +397,5 @@ int main(int argc, char ** argv) {
     LOG_TEE("Log end\n")
 #endif // LOG_DISABLE_LOGS

-    return 0;
+    return interrupted.load() ? 130 : 0;
 }
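
A side note on the exit status (a gloss, not part of the commit message): shells report a process killed by SIGINT as 128 + SIGINT = 130, so returning 130 from main when the flag is set preserves the exit status the old handler produced with _exit(130). For example:

    #include <csignal>
    #include <cstdio>

    int main() {
        // SIGINT is 2 on the supported platforms, so 128 + SIGINT == 130,
        // matching both the old _exit(130) and the new return value.
        std::printf("128 + SIGINT = %d\n", 128 + SIGINT);
        return 0;
    }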