Clean up interrupt handling, avoid globals
parent 62c5c6f5c3
commit d75698c3b0
1 changed file with 23 additions and 37 deletions
@@ -10,6 +10,7 @@
 #include "build-info.h"
 #include "grammar-parser.h"
 
+#include <atomic>
 #include <cassert>
 #include <cinttypes>
 #include <cmath>
@@ -38,12 +39,7 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-static llama_context ** g_ctx;
-static llama_model ** g_model;
-static gpt_params * g_params;
-static std::vector<llama_token> * g_input_tokens;
-static std::ostringstream * g_output_ss;
-static std::vector<llama_token> * g_output_tokens;
+static std::atomic<bool> interrupted {false};
 
 void write_logfile(
     const llama_context * ctx, const gpt_params & params, const llama_model * model,
@@ -91,11 +87,7 @@ void write_logfile(
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
 void sigint_handler(int signo) {
     if (signo == SIGINT) {
-        console::cleanup();
-        printf("\n");
-        llama_print_timings(*g_ctx);
-        write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
-        _exit(130);
+        interrupted.store(true);
     }
 }
 #endif
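The handler change above is the heart of this commit: the old SIGINT handler called console::cleanup(), printf(), llama_print_timings() and write_logfile() in signal context before _exit(130), and none of those calls is async-signal-safe. The new handler only performs a lock-free atomic store, and the main loops poll the flag and shut down normally. A minimal sketch of that pattern, assuming the handler is registered with std::signal (the registration code is outside this diff and may differ, e.g. sigaction on POSIX or a console handler on Windows):

    // Sketch only: the SIGINT handler does nothing but set an atomic flag;
    // cleanup and reporting happen on the main thread after the loop exits.
    #include <atomic>
    #include <csignal>
    #include <cstdio>

    static std::atomic<bool> interrupted {false};

    static void sigint_handler(int signo) {
        if (signo == SIGINT) {
            interrupted.store(true);   // the only work done in signal context
        }
    }

    int main() {
        std::signal(SIGINT, sigint_handler);

        long n_remain = 1000000000L;
        while (n_remain > 0 && interrupted.load() == false) {
            --n_remain;                // stands in for generating one token
        }

        // Timings, log files and console cleanup belong here, not in the handler.
        std::printf("\nstopped with n_remain = %ld\n", n_remain);
        return interrupted.load() ? 130 : 0;
    }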
@@ -173,9 +165,6 @@ bool initialize(llama_context **ctx_p, llama_model **model_p, gpt_params & param
     LOG("%s: llama backend init\n", __func__);
     llama_backend_init(params.numa);
 
-    g_model = model_p;
-    g_ctx = ctx_p;
-
     // load the model and apply lora adapter, if any
     LOG("%s: load the model and apply lora adapter, if any\n", __func__);
     std::tie(*model_p, *ctx_p) = llama_init_from_gpt_params(params);
@@ -275,7 +264,7 @@ bool initialize(llama_context **ctx_p, llama_model **model_p, gpt_params & param
 
 bool feed_prompt(llama_context *ctx, const gpt_params * params, llama_token * tokens, int tokens_len, int n_past) {
     console::set_display(console::prompt);
-    while (tokens_len > 0) {
+    while (tokens_len > 0 && interrupted.load() == false) {
         const int this_chunk_size = std::min(tokens_len, params->n_batch);
 
         if (llama_eval(ctx, tokens, this_chunk_size, n_past, params->n_threads)) {
@@ -302,7 +291,6 @@ bool feed_prompt(llama_context *ctx, const gpt_params * params, llama_token * to
 
 int main(int argc, char ** argv) {
     gpt_params params;
-    g_params = &params;
 
     if (gpt_params_parse(argc, argv, params) == false) {
         return 1;
@@ -329,10 +317,7 @@ int main(int argc, char ** argv) {
 
     const int n_ctx = llama_n_ctx(ctx);
     int n_remain = params.n_predict;
-
-    std::vector<int> input_tokens; g_input_tokens = &input_tokens;
-    std::vector<int> output_tokens; g_output_tokens = &output_tokens;
-    std::ostringstream output_ss; g_output_ss = &output_ss;
+    std::vector<int> input_tokens;
 
     {
         LOG("warming up the model with an empty run\n");
@@ -356,23 +341,10 @@ int main(int argc, char ** argv) {
     std::vector<llama_token_data> candidates;
     candidates.reserve(llama_n_vocab(ctx));
 
-    // Required to match output from main example with a specific seed - but why?
-    if (false) {
-        llama_token id = llama_sample_token(ctx, NULL, grammar, params, last_tokens, candidates);
-        if (llama_eval(ctx, &id, 1, last_tokens.size(), params.n_threads)) {
-            LOG_TEE("%s : failed to eval\n", __func__);
-            return 1;
-        }
-        const std::string token_str = llama_token_to_piece(ctx, id);
-        fputs(token_str.c_str(), stdout);
-        fflush(stdout);
-    }
-
-    while (n_remain > 0) {
+    while (n_remain > 0 && interrupted.load() == false) {
         const llama_token id = llama_sample_token(ctx, NULL, grammar, params, last_tokens, candidates);
 
         last_tokens.push_back(id);
-        output_tokens.push_back(id);
         --n_remain;
 
         LOG("n_remain: %d\n", n_remain);
@@ -384,8 +356,6 @@ int main(int argc, char ** argv) {
         }
 
         const std::string token_str = llama_token_to_piece(ctx, id);
-
-        output_ss << token_str;
         fputs(token_str.c_str(), stdout);
         fflush(stdout);
 
@@ -396,6 +366,22 @@ int main(int argc, char ** argv) {
         }
     }
 
+    std::vector<int> output_tokens;
+    std::ostringstream output_ss;
+    const size_t prompt_size = prompt_tokens.size();
+
+    for (size_t i = 0; i < last_tokens.size(); i++) {
+        const std::string token_str = llama_token_to_piece(ctx, last_tokens[i]);
+        if (i >= prompt_size) {
+            output_ss << token_str;
+            output_tokens.push_back(last_tokens[i]);
+        }
+
+    }
+
+    console::cleanup();
+    printf("\n");
+
     llama_print_timings(ctx);
     write_logfile(ctx, params, model, prompt_tokens, output_ss.str(), output_tokens);
 
@@ -411,5 +397,5 @@ int main(int argc, char ** argv) {
     LOG_TEE("Log end\n")
 #endif // LOG_DISABLE_LOGS
 
-    return 0;
+    return interrupted.load() ? 130 : 0;
 }
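Returning interrupted.load() ? 130 : 0 preserves the exit status the old handler produced with _exit(130), but only after normal cleanup and logging have run. The value follows the common shell convention of 128 + signal number, and SIGINT is signal 2 on POSIX systems. A tiny standalone illustration of that arithmetic (not part of the example itself):

    #include <csignal>
    #include <cstdio>

    int main() {
        // Shells report a process killed by signal N as exiting with 128 + N;
        // SIGINT is 2, so an interrupted run is reported as 130.
        std::printf("128 + SIGINT = %d\n", 128 + SIGINT);
        return 0;
    }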