This commit is contained in:
Olivier Chafik 2025-01-22 11:23:37 +00:00
parent 63387c6dca
commit a4226365bf
2 changed files with 11 additions and 15 deletions

View file

@ -1158,20 +1158,21 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
if (grammar.awaiting_trigger) { if (grammar.awaiting_trigger) {
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
grammar.awaiting_trigger = false; grammar.awaiting_trigger = false;
grammar.trigger_buffer.clear();
llama_grammar_accept_str(grammar, grammar.vocab->token_to_piece(token)); llama_grammar_accept_str(grammar, grammar.vocab->token_to_piece(token));
return; return;
} else { } else {
// TODO: consider a smarter incremental substring search algorithm (store last position to search from).
grammar.trigger_buffer += grammar.vocab->token_to_piece(token); grammar.trigger_buffer += grammar.vocab->token_to_piece(token);
for (const auto & word : grammar.trigger_words) { for (const auto & word : grammar.trigger_words) {
auto pos = grammar.trigger_buffer.find(word); auto pos = grammar.trigger_buffer.find(word);
if (pos == std::string::npos) { if (pos != std::string::npos) {
continue; grammar.awaiting_trigger = false;
auto constrained_str = grammar.trigger_buffer.substr(pos);
grammar.trigger_buffer.clear();
llama_grammar_accept_str(grammar, constrained_str);
return;
} }
grammar.awaiting_trigger = false;
auto constrained_str = grammar.trigger_buffer.substr(pos);
llama_grammar_accept_str(grammar, constrained_str);
grammar.trigger_buffer.clear();
return;
} }
return; return;
} }

View file

@ -1443,14 +1443,9 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
for (auto & word : ctx->grammar->trigger_words) { for (auto & word : ctx->grammar->trigger_words) {
trigger_words.push_back(word.c_str()); trigger_words.push_back(word.c_str());
} }
auto * grammar_new = llama_grammar_init_impl( auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
ctx->grammar->vocab, trigger_words.data(), trigger_words.size(),
ctx->grammar_str.c_str(), ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
ctx->grammar_root.c_str(),
trigger_words.data(),
trigger_words.size(),
ctx->grammar->trigger_tokens.data(),
ctx->grammar->trigger_tokens.size());
llama_grammar_free_impl(ctx->grammar); llama_grammar_free_impl(ctx->grammar);
ctx->grammar = grammar_new; ctx->grammar = grammar_new;