Pass grammar laziness all the way down to sampler (need to print special trigger tokens e.g. for Nemo even w/ tool_choice=required)

ochafik 2025-01-27 22:46:17 +00:00
parent ad229783c5
commit 90effb845f
12 changed files with 62 additions and 49 deletions
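
In short: each chat handler below now sets a grammar_lazy flag (lazy unless tool_choice=required) and registers its grammar triggers unconditionally, where it previously registered them only in the lazy case. The flag is threaded from common_chat_data through the sampling params down to llama_grammar and the sampler, so that special trigger tokens (e.g. Mistral Nemo's [TOOL_CALLS]) are still printed even when the grammar is active from the first token. A minimal sketch of the pattern the handlers follow after this change, with identifiers taken from the diff itself:

    common_chat_data data;
    // lazy unless the caller forces a tool call; with tool_choice=required
    // the grammar constrains sampling from the very first token
    data.grammar_lazy = params.tool_choice != "required";
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        // ... per-model tool-call rules ...
    }, grammar_options);
    // registered unconditionally: for lazy grammars this enables the grammar;
    // for non-lazy grammars it marks special tokens whose printing is forced
    data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});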

View file

@@ -279,6 +279,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
}
: tool_call;
data.grammar_lazy = false;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
builder.add_schema("root", schema);
}, grammar_options);
@@ -319,6 +320,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(params.tools, [&](const json & tool) {
@@ -352,9 +354,7 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
}
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
}, grammar_options);
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
}
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "mistral nemo tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
@@ -369,6 +369,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
@@ -385,14 +386,10 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
"\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) +
" \"}\""));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
}
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
});
tool_rules.push_back(builder.add_rule("builtin-tool-call", "\"<|python_tag|>\" .*"));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
}
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
builder.add_rule("root", string_join(tool_rules, " | "));
}, grammar_options);
data.additional_stops.push_back("<|eom_id|>");
@@ -429,6 +426,7 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
@@ -446,9 +444,7 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
"\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) +
" \"}\""));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
}
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
});
builder.add_rule("root", string_join(tool_rules, " | "));
@@ -468,8 +464,7 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
data.grammar = "root ::= .*";
// data.grammar = "root ::= .*";
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(params.tools, [&](const json & tool) {
@@ -480,9 +475,7 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
tool_rules.push_back(builder.add_rule(name + "-call",
"\"<tool▁call▁begin>function<tool▁sep>" + name + "\\n```json\\n\" " + args_rule + " \"```<tool▁call▁end>\""));
});
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<tool▁calls▁begin>", /* .at_start = */ false});
}
data.grammar_triggers.push_back({"<tool▁calls▁begin>", /* .at_start = */ false});
builder.add_rule("root", "\"<tool▁calls▁begin>\" (" + string_join(tool_rules, " | ") + ")" + (params.parallel_tool_calls ? "*" : "") + " space");
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
@@ -499,6 +492,7 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(params.tools, [&](const json & tool) {
@@ -525,9 +519,7 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
}
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
}, grammar_options);
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
}
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "firefunction v2 tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
@@ -542,6 +534,7 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules;
@@ -552,10 +545,8 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
auto args_rule = builder.add_schema(name + "-args", parameters);
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({name, /* .at_start = */ true});
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
}
data.grammar_triggers.push_back({name, /* .at_start = */ true});
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
});
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
if (params.parallel_tool_calls) {
@@ -591,6 +582,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
std::string python_code_argument_name;
auto has_raw_python = false;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(params.tools, [&](const json & tool) {
@@ -624,15 +616,11 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
});
if (has_raw_python) {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
}
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
}
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
}
data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
@@ -666,6 +654,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(params.tools, [&](const json & tool) {
@@ -684,9 +673,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
});
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
}
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
@@ -701,7 +688,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
std::sregex_iterator rend;
std::sregex_iterator rit(input.begin(), end, start_pattern);
if (rit == rend) {
return {"assistant", input, {}};
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
}
common_chat_msg result;
@@ -732,7 +723,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
}
return result;
} catch (const std::exception & e) {
return {"assistant", input, {}};
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
}
});
return data;
@@ -744,6 +739,7 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "content-only";
data.parser = std::make_unique<text_chat_parser>();
data.grammar_lazy = false;
if (!params.json_schema.is_null()) {
if (!params.grammar.empty()) {
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");

View file

@@ -42,6 +42,7 @@ struct common_chat_data {
std::vector<std::string> additional_stops;
std::unique_ptr<class common_chat_parser> parser;
std::string format; // For debugging and testing.
bool grammar_lazy = false;
};
struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);

View file

@@ -160,8 +160,9 @@ struct common_params_sampling {
};
std::string grammar; // optional BNF-like grammar to constrain sampling
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to enable grammar
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to enable grammar
bool grammar_lazy;
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and to force printing of special trigger tokens
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
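
A hedged sketch of how a caller might fill these fields (names taken from this header; the grammar string and trigger word are illustrative, matching the Nemo case from the commit message):

    common_params_sampling sparams;
    sparams.grammar      = "root ::= \"[TOOL_CALLS]\" .*"; // toy grammar
    sparams.grammar_lazy = true; // wait for a trigger before constraining sampling
    sparams.grammar_trigger_words.push_back({"[TOOL_CALLS]", /* .at_start = */ true});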

View file

@@ -159,6 +159,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root",
params.grammar_lazy,
trigger_words.data(), trigger_words.size(),
params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()),
/* .chain = */ llama_sampler_chain_init(lparams),

View file

@@ -76,7 +76,7 @@ int main(int argc, char** argv) {
grammar_str = buffer.str();
}
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0);
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0);
if (grammar == nullptr) {
fprintf(stdout, "Failed to initialize llama_grammar\n");
return 1;

View file

@@ -3816,6 +3816,7 @@ int main(int argc, char ** argv) {
task.params.oaicompat = oaicompat;
task.params.oaicompat_cmpl_id = completion_id;
task.params.sampling.grammar = chat_data.grammar;
task.params.sampling.grammar_lazy = chat_data.grammar_lazy;
for (const auto & trigger : chat_data.grammar_triggers) {
auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
@@ -3830,6 +3831,9 @@ int main(int argc, char ** argv) {
if (chat_data.parser) {
task.params.chat_parser = i == tokenized_prompts.size() ? std::move(chat_data.parser) : std::move(chat_data.parser->clone());
}
if (task.params.sampling.grammar_lazy) {
GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0);
}
// oaicompat_model is already populated by params_from_json_cmpl
tasks.push_back(task);

View file

@@ -1198,6 +1198,7 @@ extern "C" {
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
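
A hedged usage sketch of the extended C API (argument order from this declaration plus the call sites later in this diff; the trailing num_trigger_tokens parameter is cut off by the hunk but visible in those call sites):

    const char * trigger_words[] = { "[TOOL_CALLS]" };
    struct llama_sampler * smpl = llama_sampler_init_grammar(
        vocab, grammar_str, "root",
        /* lazy = */ true,
        trigger_words, 1,
        /* trigger_tokens = */ NULL, 0);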

View file

@@ -964,7 +964,8 @@ struct llama_grammar * llama_grammar_init_impl(
vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
/* .partial_utf8 = */ {},
/* .lazy = */ false,
/* .awaiting_trigger = */ false,
/* .trigger_buffer = */ "",
/* .trigger_tokens = */ {},
@@ -976,6 +977,7 @@ struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
@@ -1069,8 +1071,9 @@ struct llama_grammar * llama_grammar_init_impl(
vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
/* .awaiting_trigger = */ vec_trigger_tokens.size() > 0 || vec_trigger_words.size() > 0,
/* .partial_utf8 = */ {},
/* .lazy = */ lazy,
/* .awaiting_trigger = */ lazy,
/* .trigger_buffer = */ "",
std::move(vec_trigger_tokens),
std::move(vec_trigger_words),
@@ -1091,6 +1094,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
grammar.rules,
grammar.stacks,
grammar.partial_utf8,
grammar.lazy,
grammar.awaiting_trigger,
grammar.trigger_buffer,
grammar.trigger_tokens,

View file

@@ -116,9 +116,12 @@ struct llama_grammar {
llama_partial_utf8 partial_utf8;
// lazy grammars wait for trigger words or tokens before constraining the sampling.
bool awaiting_trigger;
std::string trigger_buffer;
std::vector<llama_token> trigger_tokens;
// we still have trigger_tokens for non-lazy grammars, to force printing of special trigger tokens
// (useful e.g. for tool_choice=required)
bool lazy; // Useful when resetting the grammar
bool awaiting_trigger; // Initialized to lazy
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or whose printing is forced (even if special).
std::vector<std::string> trigger_words;
};
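
Given those field comments, the intended lifecycle is roughly the following; this is a hedged sketch of the accept path, not the literal implementation:

    // while awaiting_trigger is set, the grammar does not constrain sampling yet;
    // it only watches the generated text for one of its triggers
    void grammar_accept_sketch(llama_grammar & g, llama_token token, const std::string & piece) {
        if (g.awaiting_trigger) {
            for (auto t : g.trigger_tokens) {
                if (t == token) {
                    // a trigger token fires immediately; constrain from here on
                    g.awaiting_trigger = false;
                    g.trigger_buffer.clear();
                    return;
                }
            }
            g.trigger_buffer += piece;
            for (const auto & word : g.trigger_words) {
                if (g.trigger_buffer.find(word) != std::string::npos) {
                    g.awaiting_trigger = false;
                    // replay the buffered text from the trigger onwards through
                    // the grammar stacks (omitted here), then drop the buffer
                    g.trigger_buffer.clear();
                    return;
                }
            }
            return; // still unconstrained
        }
        // normal path: advance the grammar stacks with this piece
    }
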
@@ -137,6 +140,7 @@ struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,

View file

@@ -1444,7 +1444,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
trigger_words.push_back(word.c_str());
}
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
trigger_words.data(), trigger_words.size(),
ctx->grammar->lazy, trigger_words.data(), trigger_words.size(),
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
llama_grammar_free_impl(ctx->grammar);
@@ -1454,7 +1454,7 @@ static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr, nullptr, 0, nullptr, 0);
auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0);
// copy the state
{
@@ -1495,6 +1495,7 @@ struct llama_sampler * llama_sampler_init_grammar(
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
@@ -1506,7 +1507,7 @@ struct llama_sampler * llama_sampler_init_grammar(
/* .vocab = */ vocab,
/* .grammar_str = */ grammar_str,
/* .grammar_root = */ grammar_root,
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
};
} else {
*ctx = {

View file

@@ -39,7 +39,7 @@ static std::string read_file(const std::string &path) {
}
static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
return std::unique_ptr<llama_grammar>(llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0));
return std::unique_ptr<llama_grammar>(llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
}
// TODO: extract to common helper (copied from test-grammar-integration.cpp)

View file

@@ -13,7 +13,7 @@
using json = nlohmann::ordered_json;
static llama_grammar * build_grammar(const std::string & grammar_str) {
return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0);
return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0);
}
static bool test_build_grammar_fails(const std::string & grammar_str) {