diff --git a/examples/run/run.cpp b/examples/run/run.cpp index dd9ea79e8..ea467361f 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -729,11 +729,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) { // Function to tokenize the prompt static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt, - std::vector & prompt_tokens) { - const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true); + std::vector & prompt_tokens, const LlamaData & llama_data, + const bool is_first) { + const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); prompt_tokens.resize(n_prompt_tokens); - if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, - true) < 0) { + if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), + llama_get_kv_cache_used_cells(llama_data.context.get()) == 0, true) < 0) { printe("failed to tokenize the prompt\n"); return -1; } @@ -774,11 +775,11 @@ static void print_word_and_concatenate_to_response(const std::string & piece, st } // helper function to evaluate a prompt and generate a response -static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) { +static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response, const bool is_first) { const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get()); std::vector tokens; - if (tokenize_prompt(vocab, prompt, tokens) < 0) { + if (tokenize_prompt(vocab, prompt, tokens, llama_data, is_first) < 0) { return 1; } @@ -852,13 +853,13 @@ static int read_user_input(std::string & user_input) { // Function to generate a response based on the prompt static int generate_response(LlamaData & llama_data, const std::string & prompt, std::string & response, - const bool stdout_a_terminal) { + const bool stdout_a_terminal, const int prev_len) { // Set response color if (stdout_a_terminal) { printf("\033[33m"); } - if (generate(llama_data, prompt, response)) { + if (generate(llama_data, prompt, response, prev_len == 0)) { printe("failed to generate response\n"); return 1; } @@ -948,7 +949,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) { std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len); std::string response; - if (generate_response(llama_data, prompt, response, stdout_a_terminal)) { + if (generate_response(llama_data, prompt, response, stdout_a_terminal, prev_len)) { return 1; }