Apply some code cleanup suggestions. Thanks!

KerfuffleV2 2023-09-13 12:32:15 -06:00
parent d75698c3b0
commit c7e1427bf1

@@ -219,7 +219,7 @@ bool initialize(llama_context **ctx_p, llama_model **model_p, gpt_params & param
         LOG_TEE("\n");
     }
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
     struct sigaction sigint_action;
     sigint_action.sa_handler = sigint_handler;
     sigemptyset (&sigint_action.sa_mask);
@@ -264,7 +264,7 @@ bool feed_prompt(llama_context *ctx, const gpt_params * params, llama_token * tokens, int tokens_len, int n_past) {
 bool feed_prompt(llama_context *ctx, const gpt_params * params, llama_token * tokens, int tokens_len, int n_past) {
     console::set_display(console::prompt);
-    while (tokens_len > 0 && interrupted.load() == false) {
+    while (tokens_len > 0 && !interrupted) {
         const int this_chunk_size = std::min(tokens_len, params->n_batch);
         if (llama_eval(ctx, tokens, this_chunk_size, n_past, params->n_threads)) {
@@ -341,7 +341,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token_data> candidates;
     candidates.reserve(llama_n_vocab(ctx));
-    while (n_remain > 0 && interrupted.load() == false) {
+    while (n_remain > 0 && !interrupted) {
         const llama_token id = llama_sample_token(ctx, NULL, grammar, params, last_tokens, candidates);
         last_tokens.push_back(id);
@@ -369,6 +369,7 @@ int main(int argc, char ** argv) {
     std::vector<int> output_tokens;
     std::ostringstream output_ss;
     const size_t prompt_size = prompt_tokens.size();
+    output_tokens.reserve(last_tokens.size() - prompt_size);
     for (size_t i = 0; i < last_tokens.size(); i++) {
         const std::string token_str = llama_token_to_piece(ctx, last_tokens[i]);
@@ -397,5 +398,5 @@ int main(int argc, char ** argv) {
     LOG_TEE("Log end\n")
 #endif // LOG_DISABLE_LOGS
-    return interrupted.load() ? 130 : 0;
+    return interrupted ? 130 : 0;
 }
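
To make the intent of the two cleanups concrete, here is a minimal standalone sketch of the patterns the commit moves to: reading a std::atomic<bool> through its implicit conversion operator instead of an explicit .load(), and reserving vector capacity before a push_back loop. The sigaction setup mirrors the context shown in the first hunk; the work loop, n_total, and results are hypothetical stand-ins, not code from this repository.

    // Sketch of the commit's cleanup patterns, not the example program itself.
    #include <atomic>
    #include <csignal>
    #include <cstdio>
    #include <vector>

    static std::atomic<bool> interrupted(false);

    static void sigint_handler(int /*signo*/) {
        // Only async-signal-safe work belongs in a handler;
        // a lock-free atomic store qualifies.
        interrupted = true;
    }

    int main() {
    #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
        struct sigaction sigint_action;
        sigint_action.sa_handler = sigint_handler;
        sigemptyset(&sigint_action.sa_mask);
        sigint_action.sa_flags = 0;
        sigaction(SIGINT, &sigint_action, NULL);
    #endif

        const int n_total = 1000000; // hypothetical amount of work
        std::vector<int> results;
        // Reserving up front avoids repeated reallocation as the vector
        // grows -- the same reasoning as output_tokens.reserve() above.
        results.reserve(n_total);

        // `!interrupted` implicitly calls interrupted.load() with the
        // default sequentially consistent ordering, so it behaves the
        // same as the old `interrupted.load() == false` spelling.
        for (int i = 0; i < n_total && !interrupted; i++) {
            results.push_back(i);
        }

        std::printf("processed %zu items\n", results.size());
        // Exit with 130 (128 + SIGINT) when interrupted, as the diff does.
        return interrupted ? 130 : 0;
    }

The implicit conversion is purely a readability change: std::atomic<bool>::operator bool() performs the same seq_cst load, so the generated code is unchanged.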