diff --git a/main.cpp b/main.cpp
index df49839a3..7646fd42d 100644
--- a/main.cpp
+++ b/main.cpp
@@ -935,35 +935,37 @@ int main(int argc, char ** argv) {
         embd.clear();
 
         if (embd_inp.size() <= input_consumed) {
-            // out of user input, sample next token
-            const float top_k = params.top_k;
-            const float top_p = params.top_p;
-            const float temp  = params.temp;
-            const float repeat_penalty = params.repeat_penalty;
+            if (!is_interacting) {
+                // out of user input, sample next token
+                const float top_k = params.top_k;
+                const float top_p = params.top_p;
+                const float temp  = params.temp;
+                const float repeat_penalty = params.repeat_penalty;
 
-            const int n_vocab = model.hparams.n_vocab;
+                const int n_vocab = model.hparams.n_vocab;
 
-            gpt_vocab::id id = 0;
+                gpt_vocab::id id = 0;
 
-            {
-                const int64_t t_start_sample_us = ggml_time_us();
+                {
+                    const int64_t t_start_sample_us = ggml_time_us();
 
-                id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
+                    id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
 
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(id);
+                    last_n_tokens.erase(last_n_tokens.begin());
+                    last_n_tokens.push_back(id);
 
-                t_sample_us += ggml_time_us() - t_start_sample_us;
+                    t_sample_us += ggml_time_us() - t_start_sample_us;
+                }
+
+                // add it to the context
+                embd.push_back(id);
+
+                // echo this to console
+                input_noecho = false;
+
+                // decrement remaining sampling budget
+                --remaining_tokens;
             }
-
-            // add it to the context
-            embd.push_back(id);
-
-            // echo this to console
-            input_noecho = false;
-
-            // decrement remaining sampling budget
-            --remaining_tokens;
         } else {
             // some user input remains from prompt or interaction, forward it to processing
             while (embd_inp.size() > input_consumed) {
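
The hunk above wraps token sampling in an `if (!is_interacting)` guard: when the program is waiting on the user, no new token is sampled and the loop instead falls through to consume pending input. Below is a minimal standalone sketch of that control flow, not the real main.cpp; `embd_inp`, `embd`, `is_interacting`, and the fake sampler are simplified stand-ins for the surrounding state.

#include <cstdio>
#include <deque>
#include <vector>

int main() {
    std::deque<int> embd_inp = {1, 2, 3}; // pending prompt/user tokens
    std::vector<int> embd;                // tokens evaluated this iteration
    bool is_interacting   = false;        // true while waiting for user input
    int  remaining_tokens = 4;            // sampling budget
    int  next_id          = 100;          // fake "sampled" token ids

    while (remaining_tokens > 0) {
        embd.clear();

        if (embd_inp.empty()) {
            // Out of user input. Before this change, sampling happened
            // unconditionally here; the new guard skips it while interacting,
            // so the model stays quiet until the user has typed something.
            // (In the real loop, user input refills embd_inp and clears the flag.)
            if (!is_interacting) {
                embd.push_back(next_id++); // stand-in for llama_sample_top_p_top_k
                --remaining_tokens;        // only sampling spends the budget
            }
        } else {
            // Some user input remains, forward it to processing.
            embd.push_back(embd_inp.front());
            embd_inp.pop_front();
        }

        for (int id : embd) {
            std::printf("%d ", id); // stand-in for echoing tokens to the console
        }
    }
    std::printf("\n");
}

Note that `--remaining_tokens` now lives inside the guard too, so time spent interacting does not eat into the sampling budget.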