From 60325ec78e381790ba31ec5289855d6a20b5a694 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?DAN=E2=84=A2?=
Date: Sun, 3 Mar 2024 08:46:43 -0500
Subject: [PATCH] Tokenize antiprompts only once.

---
 examples/main/main.cpp | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 861a88d58..0a5bbbbd1 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -760,8 +760,27 @@ int main(int argc, char ** argv) {
                             ? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
                             : 0;
 
-                        auto tmp = ::llama_tokenize(ctx, antiprompt, false, true);
-                        if (last_output.find(antiprompt, search_start_pos) != std::string::npos || (tmp.size() == 1 && llama_sampling_last(ctx_sampling) == tmp[0])) {
+                        if (last_output.find(antiprompt, search_start_pos) != std::string::npos) {
+                            if (params.interactive) {
+                                is_interacting = true;
+                            }
+                            is_antiprompt = true;
+                            break;
+                        }
+                    }
+
+                    // tokenize reverse/antiprompt special tokens only once using static
+                    static std::vector<std::vector<llama_token>> antiprompt_ids;
+                    if (antiprompt_ids.empty()) {
+                        for (std::string& antiprompt : params.antiprompt) {
+                            antiprompt_ids.push_back(::llama_tokenize(ctx, antiprompt, false, true));
+                        }
+                    }
+
+                    // check for reverse prompt using special tokens
+                    llama_token last_token = llama_sampling_last(ctx_sampling);
+                    for (std::vector<llama_token> ids : antiprompt_ids) {
+                        if (ids.size() == 1 && last_token == ids[0]) {
                             if (params.interactive) {
                                 is_interacting = true;
                             }