tool-call: allow special tokens that are grammar triggers

Olivier Chafik 2025-01-25 04:51:53 +00:00
parent 46415d7a51
commit c479d39abd
2 changed files with 11 additions and 5 deletions

examples/server/server.cpp

@@ -2795,6 +2795,11 @@ struct server_context {
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
 
+        auto accept_special_token = [&](llama_token token) {
+            const auto & trigger_tokens = params_base.sampling.grammar_trigger_tokens;
+            return params_base.special || std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end();
+        };
+
         // first, add sampled tokens from any ongoing sequences
         for (auto & slot : slots) {
             if (slot.state != SLOT_STATE_GENERATING) {
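
A note on the hunk above: the new lambda only widens the --special gate for tokens that are registered as grammar trigger tokens; all other special tokens stay suppressed. A minimal standalone sketch of the predicate (hypothetical harness: llama_token aliased to int32_t, the params_base fields passed as plain arguments):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = std::int32_t;

// Same decision as the lambda: render the piece if the server runs with
// --special, or if the token is one of the grammar trigger tokens.
static bool accept_special_token(bool render_all_special,
                                 const std::vector<llama_token> & trigger_tokens,
                                 llama_token token) {
    return render_all_special ||
           std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end();
}

int main() {
    const std::vector<llama_token> triggers = { 42 }; // made-up trigger token id
    std::printf("%d %d\n",
                accept_special_token(false, triggers, 42),  // 1: trigger passes through
                accept_special_token(false, triggers, 7));  // 0: non-trigger special tokens stay gated
}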
@@ -3158,7 +3163,7 @@ struct server_context {
 
             completion_token_output result;
             result.tok          = id;
-            result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(result.tok));
             result.prob         = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
 
             if (slot.params.sampling.n_probs > 0) {
@@ -3247,7 +3252,7 @@ struct server_context {
 
             completion_token_output result;
             result.tok          = ids[i];
-            result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(result.tok));
             result.prob         = 1.0f; // set later
 
             // TODO: set result.probs

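Both server.cpp call sites change only the third argument of common_token_to_piece, which controls whether a special token is rendered as its literal text or as an empty piece. A toy model of that gate (not the real function, whose first parameter is a llama_context):

#include <string>

// Stand-in for the `special` flag of common_token_to_piece: with the flag off,
// a special token contributes nothing to the streamed text.
static std::string piece_for(const std::string & text, bool is_special, bool allow_special) {
    return (is_special && !allow_special) ? std::string() : text;
}

int main() {
    // Before this commit, a trigger piece like "<tool_call>" (hypothetical) was
    // dropped unless the server ran with --special; accept_special_token(...)
    // now passes true for trigger tokens, so the piece reaches the client.
    return piece_for("<tool_call>", true, /*allow_special=*/true) == "<tool_call>" ? 0 : 1;
}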
src/llama-grammar.cpp

@@ -1155,15 +1155,17 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
 void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
     GGML_ASSERT(grammar.vocab != nullptr);
 
+    const auto & piece = grammar.vocab->token_to_piece(token);
+
     if (grammar.awaiting_trigger) {
         if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
             grammar.awaiting_trigger = false;
             grammar.trigger_buffer.clear();
-            llama_grammar_accept_str(grammar, grammar.vocab->token_to_piece(token));
+            llama_grammar_accept_str(grammar, piece);
             return;
         } else {
             // TODO: consider a smarter incremental substring search algorithm (store last position to search from).
-            grammar.trigger_buffer += grammar.vocab->token_to_piece(token);
+            grammar.trigger_buffer += piece;
             for (const auto & word : grammar.trigger_words) {
                 auto pos = grammar.trigger_buffer.find(word);
                 if (pos != std::string::npos) {
@@ -1187,7 +1189,6 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
         GGML_ABORT("fatal error");
     }
 
-    const std::string & piece = grammar.vocab->token_to_piece(token);
 
     llama_grammar_accept_str(grammar, piece);
 }
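
Taken together, the grammar-side change hoists the token_to_piece lookup to the top of the function, and the server-side change lets a special trigger token actually reach the grammar. A simplified state-machine sketch of the awaiting-trigger logic (hypothetical types; the branch truncated at the end of the first hunk, which feeds the buffer from the found trigger word onward, is paraphrased from the surrounding code):

#include <algorithm>
#include <string>
#include <vector>

struct grammar_sketch {
    bool awaiting_trigger = true;
    std::string trigger_buffer;
    std::vector<int> trigger_tokens;          // token ids that fire the grammar
    std::vector<std::string> trigger_words;   // literal strings that fire it
    std::string accepted;                     // stands in for llama_grammar_accept_str

    void accept(int token, const std::string & piece) {
        if (!awaiting_trigger) {
            accepted += piece;                // normal constrained decoding
            return;
        }
        // A registered trigger token fires immediately; its own piece is the
        // first text the grammar sees.
        if (std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end()) {
            awaiting_trigger = false;
            trigger_buffer.clear();
            accepted += piece;
            return;
        }
        // Otherwise buffer pieces and scan for a trigger word; on a hit, feed
        // the grammar everything from the trigger word onward.
        trigger_buffer += piece;
        for (const auto & word : trigger_words) {
            auto pos = trigger_buffer.find(word);
            if (pos != std::string::npos) {
                awaiting_trigger = false;
                accepted += trigger_buffer.substr(pos);
                trigger_buffer.clear();
                return;
            }
        }
    }
};

int main() {
    grammar_sketch g;
    g.trigger_words = { "<tool_call>" };
    for (const char * p : { "Hello ", "<tool", "_call>", "{\"name\":..." }) {
        g.accept(/*token=*/0, p);
    }
    // Free-form text before the trigger is discarded; the grammar sees the
    // trigger word and everything after it.
    return g.accepted == "<tool_call>{\"name\":..." ? 0 : 1;
}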