tool-call: allow special tokens that are grammar triggers

This commit is contained in:
Olivier Chafik 2025-01-25 04:51:53 +00:00
parent 46415d7a51
commit c479d39abd
2 changed files with 11 additions and 5 deletions

View file

@ -2795,6 +2795,11 @@ struct server_context {
// track if given slot can be batched with slots already in the batch // track if given slot can be batched with slots already in the batch
server_slot * slot_batched = nullptr; server_slot * slot_batched = nullptr;
auto accept_special_token = [&](llama_token token) {
const auto & trigger_tokens = params_base.sampling.grammar_trigger_tokens;
return params_base.special || std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end();
};
// frist, add sampled tokens from any ongoing sequences // frist, add sampled tokens from any ongoing sequences
for (auto & slot : slots) { for (auto & slot : slots) {
if (slot.state != SLOT_STATE_GENERATING) { if (slot.state != SLOT_STATE_GENERATING) {
@ -3158,7 +3163,7 @@ struct server_context {
completion_token_output result; completion_token_output result;
result.tok = id; result.tok = id;
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special); result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(result.tok));
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
if (slot.params.sampling.n_probs > 0) { if (slot.params.sampling.n_probs > 0) {
@ -3247,7 +3252,7 @@ struct server_context {
completion_token_output result; completion_token_output result;
result.tok = ids[i]; result.tok = ids[i];
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special); result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(result.tok));
result.prob = 1.0f; // set later result.prob = 1.0f; // set later
// TODO: set result.probs // TODO: set result.probs

View file

@ -1155,15 +1155,17 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
GGML_ASSERT(grammar.vocab != nullptr); GGML_ASSERT(grammar.vocab != nullptr);
const auto & piece = grammar.vocab->token_to_piece(token);
if (grammar.awaiting_trigger) { if (grammar.awaiting_trigger) {
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
grammar.awaiting_trigger = false; grammar.awaiting_trigger = false;
grammar.trigger_buffer.clear(); grammar.trigger_buffer.clear();
llama_grammar_accept_str(grammar, grammar.vocab->token_to_piece(token)); llama_grammar_accept_str(grammar, piece);
return; return;
} else { } else {
// TODO: consider a smarter incremental substring search algorithm (store last position to search from). // TODO: consider a smarter incremental substring search algorithm (store last position to search from).
grammar.trigger_buffer += grammar.vocab->token_to_piece(token); grammar.trigger_buffer += piece;
for (const auto & word : grammar.trigger_words) { for (const auto & word : grammar.trigger_words) {
auto pos = grammar.trigger_buffer.find(word); auto pos = grammar.trigger_buffer.find(word);
if (pos != std::string::npos) { if (pos != std::string::npos) {
@ -1187,7 +1189,6 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
const std::string & piece = grammar.vocab->token_to_piece(token);
llama_grammar_accept_str(grammar, piece); llama_grammar_accept_str(grammar, piece);
} }