diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 939e6c36a..a8ea4d05b 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2795,6 +2795,11 @@ struct server_context {
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
 
+        auto accept_special_token = [&](llama_token token) {
+            const auto & trigger_tokens = params_base.sampling.grammar_trigger_tokens;
+            return params_base.special || std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end();
+        };
+
         // frist, add sampled tokens from any ongoing sequences
         for (auto & slot : slots) {
             if (slot.state != SLOT_STATE_GENERATING) {
@@ -3158,7 +3163,7 @@ struct server_context {
 
                 completion_token_output result;
                 result.tok = id;
-                result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(result.tok));
                 result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
 
                 if (slot.params.sampling.n_probs > 0) {
@@ -3247,7 +3252,7 @@ struct server_context {
 
                     completion_token_output result;
                     result.tok = ids[i];
-                    result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                    result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(result.tok));
                     result.prob = 1.0f; // set later
 
                     // TODO: set result.probs
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
index 2c1ae0975..501b0037b 100644
--- a/src/llama-grammar.cpp
+++ b/src/llama-grammar.cpp
@@ -1155,15 +1155,17 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
 void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
     GGML_ASSERT(grammar.vocab != nullptr);
 
+    const auto & piece = grammar.vocab->token_to_piece(token);
+
     if (grammar.awaiting_trigger) {
         if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
             grammar.awaiting_trigger = false;
             grammar.trigger_buffer.clear();
-            llama_grammar_accept_str(grammar, grammar.vocab->token_to_piece(token));
+            llama_grammar_accept_str(grammar, piece);
             return;
         } else {
             // TODO: consider a smarter incremental substring search algorithm (store last position to search from).
-            grammar.trigger_buffer += grammar.vocab->token_to_piece(token);
+            grammar.trigger_buffer += piece;
             for (const auto & word : grammar.trigger_words) {
                 auto pos = grammar.trigger_buffer.find(word);
                 if (pos != std::string::npos) {
@@ -1187,7 +1189,6 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
         GGML_ABORT("fatal error");
     }
 
-    const std::string & piece = grammar.vocab->token_to_piece(token);
     llama_grammar_accept_str(grammar, piece);
 }
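
The server-side change gates how a sampled token is rendered: the piece is produced with special-token rendering enabled either when the user requested special tokens globally, or when the token is one of the grammar trigger tokens, so trigger markers still show up in the generated text even with `--special` off. Below is a minimal, self-contained sketch (not the server's actual code) of that rule; `special_enabled` and `trigger_tokens` stand in for `params_base.special` and `params_base.sampling.grammar_trigger_tokens`, and the token ids are made up.

```cpp
// Hedged sketch of the gating rule added in server.cpp: a special token is
// rendered as text if --special is on OR the token is a grammar trigger token.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = int32_t;

static bool accept_special_token(llama_token token, bool special_enabled,
                                 const std::vector<llama_token> & trigger_tokens) {
    return special_enabled ||
           std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end();
}

int main() {
    const std::vector<llama_token> trigger_tokens = { 32000, 32001 }; // hypothetical ids

    std::printf("%d\n", accept_special_token(32000, /*special_enabled=*/false, trigger_tokens)); // 1: trigger token
    std::printf("%d\n", accept_special_token(15,    /*special_enabled=*/false, trigger_tokens)); // 0: ordinary special token, --special off
    std::printf("%d\n", accept_special_token(15,    /*special_enabled=*/true,  trigger_tokens)); // 1: --special on
}
```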
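On the grammar side, the patch only hoists the `token_to_piece` conversion so the piece is computed once per call. The surrounding logic (unchanged here) buffers pieces while the grammar is awaiting a trigger and starts feeding the grammar from the first trigger word onward. A rough, simplified sketch of that buffering behaviour, using stand-in types rather than the library's real API:

```cpp
// Simplified stand-in for the trigger-word path of llama_grammar_accept_impl.
// `Grammar` and `accept_str` are illustrative, not llama.cpp types.
#include <iostream>
#include <string>
#include <vector>

struct Grammar {
    bool awaiting_trigger = true;
    std::string trigger_buffer;
    std::vector<std::string> trigger_words;
    std::string accepted; // stands in for the real grammar state
};

static void accept_str(Grammar & g, const std::string & s) { g.accepted += s; }

static void accept_piece(Grammar & g, const std::string & piece) {
    if (!g.awaiting_trigger) {
        accept_str(g, piece);
        return;
    }
    g.trigger_buffer += piece;
    for (const auto & word : g.trigger_words) {
        auto pos = g.trigger_buffer.find(word);
        if (pos != std::string::npos) {
            g.awaiting_trigger = false;
            // feed the grammar from the start of the trigger word onward
            accept_str(g, g.trigger_buffer.substr(pos));
            g.trigger_buffer.clear();
            return;
        }
    }
}

int main() {
    Grammar g;
    g.trigger_words = { "<tool_call>" };

    const std::vector<std::string> pieces = { "Sure, ", "<tool_", "call>", "{\"name\"" };
    for (const auto & piece : pieces) {
        accept_piece(g, piece);
    }
    std::cout << g.accepted << "\n"; // prints: <tool_call>{"name"
}
```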