antiprompts: ensure partial match is at end of string (or else server stops sending replies)

This commit is contained in:
Olivier Chafik 2024-10-03 19:10:21 +01:00
parent fa8df0c350
commit ece12b074f
2 changed files with 26 additions and 2 deletions

View file

@ -714,8 +714,9 @@ private:
MatchResult findFirstMatch(const std::string& text, size_t offset = 0) {
TrieNode* current = &root;
MatchResult partialMatch{std::string::npos, "", true, 0, false};
auto text_length = text.length();
for (size_t i = offset; i < text.length(); ++i) {
for (size_t i = offset; i < text_length; ++i) {
char c = text[i];
while (current != &root && current->children.find(c) == current->children.end()) {
current = current->fail;
@ -745,7 +746,9 @@ private:
// If we've found a partial match and haven't returned a full match, return the partial match
if (partialMatch.pos != std::string::npos) {
return partialMatch;
if (partialMatch.pos + partialMatch.matchLength == text_length) {
return partialMatch;
}
}
return {std::string::npos, "", false, 0, false};

View file

@ -60,6 +60,27 @@ int main()
/* .matchLength = */ 3,
/* .is_grammar_trigger = */ false,
});
assert_equal(antiprompts.findFirstMatch(" ab c", 0), {
/* .pos = */ std::string::npos,
/* .pattern = */ "",
/* .is_partial = */ false,
/* .matchLength = */ 0,
/* .is_grammar_trigger = */ false,
});
assert_equal(antiprompts.findFirstMatch(" abc abc", 0), {
/* .pos = */ 1,
/* .pattern = */ "abc",
/* .is_partial = */ false,
/* .matchLength = */ 3,
/* .is_grammar_trigger = */ false,
});
assert_equal(antiprompts.findFirstMatch(" ab abc", 0), {
/* .pos = */ 4,
/* .pattern = */ "abc",
/* .is_partial = */ false,
/* .matchLength = */ 3,
/* .is_grammar_trigger = */ false,
});
assert_equal(antiprompts.findFirstMatch(" bc", 0), {
/* .pos = */ 1,
/* .pattern = */ "",