diff --git a/common/common.h b/common/common.h index 3c9cc80eb..1cb518a70 100644 --- a/common/common.h +++ b/common/common.h @@ -714,8 +714,9 @@ private: MatchResult findFirstMatch(const std::string& text, size_t offset = 0) { TrieNode* current = &root; MatchResult partialMatch{std::string::npos, "", true, 0, false}; + auto text_length = text.length(); - for (size_t i = offset; i < text.length(); ++i) { + for (size_t i = offset; i < text_length; ++i) { char c = text[i]; while (current != &root && current->children.find(c) == current->children.end()) { current = current->fail; @@ -745,7 +746,9 @@ private: // If we've found a partial match and haven't returned a full match, return the partial match if (partialMatch.pos != std::string::npos) { - return partialMatch; + if (partialMatch.pos + partialMatch.matchLength == text_length) { + return partialMatch; + } } return {std::string::npos, "", false, 0, false}; diff --git a/tests/test-antiprompts.cpp b/tests/test-antiprompts.cpp index 9f9853bad..4fa688a39 100644 --- a/tests/test-antiprompts.cpp +++ b/tests/test-antiprompts.cpp @@ -60,6 +60,27 @@ int main() /* .matchLength = */ 3, /* .is_grammar_trigger = */ false, }); + assert_equal(antiprompts.findFirstMatch(" ab c", 0), { + /* .pos = */ std::string::npos, + /* .pattern = */ "", + /* .is_partial = */ false, + /* .matchLength = */ 0, + /* .is_grammar_trigger = */ false, + }); + assert_equal(antiprompts.findFirstMatch(" abc abc", 0), { + /* .pos = */ 1, + /* .pattern = */ "abc", + /* .is_partial = */ false, + /* .matchLength = */ 3, + /* .is_grammar_trigger = */ false, + }); + assert_equal(antiprompts.findFirstMatch(" ab abc", 0), { + /* .pos = */ 4, + /* .pattern = */ "abc", + /* .is_partial = */ false, + /* .matchLength = */ 3, + /* .is_grammar_trigger = */ false, + }); assert_equal(antiprompts.findFirstMatch(" bc", 0), { /* .pos = */ 1, /* .pattern = */ "",