diff --git a/examples/run/run.cpp b/examples/run/run.cpp index dd9ea79e8..d04108e71 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -729,10 +729,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) { // Function to tokenize the prompt static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt, - std::vector & prompt_tokens) { - const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true); + std::vector & prompt_tokens, const LlamaData & llama_data) { + const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0; + + const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); prompt_tokens.resize(n_prompt_tokens); - if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, + if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) { printe("failed to tokenize the prompt\n"); return -1; @@ -778,7 +780,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get()); std::vector tokens; - if (tokenize_prompt(vocab, prompt, tokens) < 0) { + if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) { return 1; } diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 26422601d..212b3fd79 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -95,13 +95,15 @@ int main(int argc, char ** argv) { llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); // helper function to evaluate a prompt and generate a response - auto generate = [&](const std::string & prompt, bool is_first) { + auto generate = [&](const std::string & prompt) { std::string response; + const bool is_first = llama_get_kv_cache_used_cells(ctx) == 0; + // tokenize the prompt const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); std::vector prompt_tokens(n_prompt_tokens); - if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) { + if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) { GGML_ABORT("failed to tokenize the prompt\n"); } @@ -180,7 +182,7 @@ int main(int argc, char ** argv) { // generate a response printf("\033[33m"); - std::string response = generate(prompt, prev_len == 0); + std::string response = generate(prompt); printf("\n\033[0m"); // add the response to the messages