diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 3dc593a48..2c1ae0975 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1158,20 +1158,21 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token if (grammar.awaiting_trigger) { if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { grammar.awaiting_trigger = false; + grammar.trigger_buffer.clear(); llama_grammar_accept_str(grammar, grammar.vocab->token_to_piece(token)); return; } else { + // TODO: consider a smarter incremental substring search algorithm (store last position to search from). grammar.trigger_buffer += grammar.vocab->token_to_piece(token); for (const auto & word : grammar.trigger_words) { auto pos = grammar.trigger_buffer.find(word); - if (pos == std::string::npos) { - continue; + if (pos != std::string::npos) { + grammar.awaiting_trigger = false; + auto constrained_str = grammar.trigger_buffer.substr(pos); + grammar.trigger_buffer.clear(); + llama_grammar_accept_str(grammar, constrained_str); + return; } - grammar.awaiting_trigger = false; - auto constrained_str = grammar.trigger_buffer.substr(pos); - llama_grammar_accept_str(grammar, constrained_str); - grammar.trigger_buffer.clear(); - return; } return; } diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 0041a67e3..82b2b474c 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1443,14 +1443,9 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { for (auto & word : ctx->grammar->trigger_words) { trigger_words.push_back(word.c_str()); } - auto * grammar_new = llama_grammar_init_impl( - ctx->grammar->vocab, - ctx->grammar_str.c_str(), - ctx->grammar_root.c_str(), - trigger_words.data(), - trigger_words.size(), - ctx->grammar->trigger_tokens.data(), - ctx->grammar->trigger_tokens.size()); + auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), + trigger_words.data(), trigger_words.size(), + ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); llama_grammar_free_impl(ctx->grammar); ctx->grammar = grammar_new;