From c7e1427bf16ad51d19e03c948d79fb79c97ec0dd Mon Sep 17 00:00:00 2001
From: KerfuffleV2
Date: Wed, 13 Sep 2023 12:32:15 -0600
Subject: [PATCH] Apply some code cleanup suggestions. Thanks!

---
 examples/simple-inference/simple-inference.cpp | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/examples/simple-inference/simple-inference.cpp b/examples/simple-inference/simple-inference.cpp
index c8c56ae34..ea96926bd 100644
--- a/examples/simple-inference/simple-inference.cpp
+++ b/examples/simple-inference/simple-inference.cpp
@@ -219,7 +219,7 @@ bool initialize(llama_context **ctx_p, llama_model **model_p, gpt_params & param
         LOG_TEE("\n");
     }
 
-    #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
     struct sigaction sigint_action;
     sigint_action.sa_handler = sigint_handler;
     sigemptyset (&sigint_action.sa_mask);
@@ -264,7 +264,7 @@ bool initialize(llama_context **ctx_p, llama_model **model_p, gpt_params & param
 bool feed_prompt(llama_context *ctx, const gpt_params * params, llama_token * tokens, int tokens_len, int n_past) {
     console::set_display(console::prompt);
 
-    while (tokens_len > 0 && interrupted.load() == false) {
+    while (tokens_len > 0 && !interrupted) {
         const int this_chunk_size = std::min(tokens_len, params->n_batch);
 
         if (llama_eval(ctx, tokens, this_chunk_size, n_past, params->n_threads)) {
@@ -341,7 +341,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token_data> candidates;
     candidates.reserve(llama_n_vocab(ctx));
 
-    while (n_remain > 0 && interrupted.load() == false) {
+    while (n_remain > 0 && !interrupted) {
         const llama_token id = llama_sample_token(ctx, NULL, grammar, params, last_tokens, candidates);
 
         last_tokens.push_back(id);
@@ -369,6 +369,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> output_tokens;
     std::ostringstream output_ss;
     const size_t prompt_size = prompt_tokens.size();
+    output_tokens.reserve(last_tokens.size() - prompt_size);
 
     for (size_t i = 0; i < last_tokens.size(); i++) {
         const std::string token_str = llama_token_to_piece(ctx, last_tokens[i]);
@@ -397,5 +398,5 @@ int main(int argc, char ** argv) {
     LOG_TEE("Log end\n")
 #endif // LOG_DISABLE_LOGS
 
-    return interrupted.load() ? 130 : 0;
+    return interrupted ? 130 : 0;
 }
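
Note (not part of the patch): the `interrupted.load() == false` -> `!interrupted` simplification works because `std::atomic<bool>` defines an implicit conversion operator that performs a sequentially consistent load, so the two forms are equivalent. The sketch below illustrates the pattern under the assumption that `interrupted` is a file-scope `std::atomic<bool>` set from a SIGINT handler, as the diff suggests; apart from the `interrupted` and `sigint_handler` names it is illustrative code, not code from simple-inference.cpp.

    #include <atomic>
    #include <chrono>
    #include <csignal>
    #include <cstdio>
    #include <thread>

    // Assumed declaration, mirroring what the patch appears to rely on.
    static std::atomic<bool> interrupted{false};

    static void sigint_handler(int /*signo*/) {
        // Only async-signal-safe work belongs in a signal handler; storing
        // to a lock-free atomic flag qualifies.
        interrupted.store(true);
    }

    int main() {
        std::signal(SIGINT, sigint_handler);

        // `!interrupted` performs an implicit atomic load (seq_cst), exactly
        // like `interrupted.load() == false` in the code being cleaned up.
        while (!interrupted) {
            std::this_thread::sleep_for(std::chrono::milliseconds(100));
        }

        std::printf("interrupted, exiting\n");
        return 130; // 128 + SIGINT(2), the same convention the example uses
    }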