tool-call: allow special tokens that are grammar triggers
This commit is contained in:
parent
46415d7a51
commit
c479d39abd
2 changed files with 11 additions and 5 deletions
|
@ -2795,6 +2795,11 @@ struct server_context {
|
|||
// track if given slot can be batched with slots already in the batch
|
||||
server_slot * slot_batched = nullptr;
|
||||
|
||||
auto accept_special_token = [&](llama_token token) {
|
||||
const auto & trigger_tokens = params_base.sampling.grammar_trigger_tokens;
|
||||
return params_base.special || std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end();
|
||||
};
|
||||
|
||||
// frist, add sampled tokens from any ongoing sequences
|
||||
for (auto & slot : slots) {
|
||||
if (slot.state != SLOT_STATE_GENERATING) {
|
||||
|
@ -3158,7 +3163,7 @@ struct server_context {
|
|||
|
||||
completion_token_output result;
|
||||
result.tok = id;
|
||||
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
|
||||
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(result.tok));
|
||||
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
||||
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
|
@ -3247,7 +3252,7 @@ struct server_context {
|
|||
completion_token_output result;
|
||||
|
||||
result.tok = ids[i];
|
||||
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
|
||||
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(result.tok));
|
||||
result.prob = 1.0f; // set later
|
||||
|
||||
// TODO: set result.probs
|
||||
|
|
|
@ -1155,15 +1155,17 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
|
|||
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
|
||||
GGML_ASSERT(grammar.vocab != nullptr);
|
||||
|
||||
const auto & piece = grammar.vocab->token_to_piece(token);
|
||||
|
||||
if (grammar.awaiting_trigger) {
|
||||
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
|
||||
grammar.awaiting_trigger = false;
|
||||
grammar.trigger_buffer.clear();
|
||||
llama_grammar_accept_str(grammar, grammar.vocab->token_to_piece(token));
|
||||
llama_grammar_accept_str(grammar, piece);
|
||||
return;
|
||||
} else {
|
||||
// TODO: consider a smarter incremental substring search algorithm (store last position to search from).
|
||||
grammar.trigger_buffer += grammar.vocab->token_to_piece(token);
|
||||
grammar.trigger_buffer += piece;
|
||||
for (const auto & word : grammar.trigger_words) {
|
||||
auto pos = grammar.trigger_buffer.find(word);
|
||||
if (pos != std::string::npos) {
|
||||
|
@ -1187,7 +1189,6 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
|||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
const std::string & piece = grammar.vocab->token_to_piece(token);
|
||||
llama_grammar_accept_str(grammar, piece);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue