Pass grammar laziness all the way down to sampler (need to print special trigger tokens e.g. for Nemo even w/ tool_choice=required)

This commit is contained in:
ochafik 2025-01-27 22:46:17 +00:00
parent ad229783c5
commit 90effb845f
12 changed files with 62 additions and 49 deletions

View file

@ -279,6 +279,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
} }
: tool_call; : tool_call;
data.grammar_lazy = false;
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
builder.add_schema("root", schema); builder.add_schema("root", schema);
}, grammar_options); }, grammar_options);
@ -319,6 +320,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__); fprintf(stderr, "[%s]\n", __func__);
common_chat_data data; common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array(); auto schemas = json::array();
foreach_function(params.tools, [&](const json & tool) { foreach_function(params.tools, [&](const json & tool) {
@ -352,9 +354,7 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
} }
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
}, grammar_options); }, grammar_options);
if (params.tool_choice != "required") { data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
}
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "mistral nemo tool calls"; data.format = "mistral nemo tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) { data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
@ -369,6 +369,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
auto builtin_tools = json {"wolfram_alpha", "brave_search"}; auto builtin_tools = json {"wolfram_alpha", "brave_search"};
common_chat_data data; common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
@ -385,14 +386,10 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
"\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) + builder.add_schema(name + "-args", parameters) +
" \"}\"")); " \"}\""));
if (params.tool_choice != "required") { data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
}
}); });
tool_rules.push_back(builder.add_rule("builtin-tool-call", "\"<|python_tag|>\" .*")); tool_rules.push_back(builder.add_rule("builtin-tool-call", "\"<|python_tag|>\" .*"));
if (params.tool_choice != "required") { data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
}
builder.add_rule("root", string_join(tool_rules, " | ")); builder.add_rule("root", string_join(tool_rules, " | "));
}, grammar_options); }, grammar_options);
data.additional_stops.push_back("<|eom_id|>"); data.additional_stops.push_back("<|eom_id|>");
@ -429,6 +426,7 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
fprintf(stderr, "[%s]\n", __func__); fprintf(stderr, "[%s]\n", __func__);
common_chat_data data; common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
@ -446,9 +444,7 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
"\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) + builder.add_schema(name + "-args", parameters) +
" \"}\"")); " \"}\""));
if (params.tool_choice != "required") { data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
}
}); });
builder.add_rule("root", string_join(tool_rules, " | ")); builder.add_rule("root", string_join(tool_rules, " | "));
@ -468,8 +464,7 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__); fprintf(stderr, "[%s]\n", __func__);
common_chat_data data; common_chat_data data;
data.grammar = "root ::= .*"; data.grammar_lazy = params.tool_choice != "required";
// data.grammar = "root ::= .*";
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
foreach_function(params.tools, [&](const json & tool) { foreach_function(params.tools, [&](const json & tool) {
@ -480,9 +475,7 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
tool_rules.push_back(builder.add_rule(name + "-call", tool_rules.push_back(builder.add_rule(name + "-call",
"\"<tool▁call▁begin>function<tool▁sep>" + name + "\\n```json\\n\" " + args_rule + " \"```<tool▁call▁end>\"")); "\"<tool▁call▁begin>function<tool▁sep>" + name + "\\n```json\\n\" " + args_rule + " \"```<tool▁call▁end>\""));
}); });
if (params.tool_choice != "required") { data.grammar_triggers.push_back({"<tool▁calls▁begin>", /* .at_start = */ false});
data.grammar_triggers.push_back({"<tool▁calls▁begin>", /* .at_start = */ false});
}
builder.add_rule("root", "\"<tool▁calls▁begin>\" (" + string_join(tool_rules, " | ") + ")" + (params.parallel_tool_calls ? "*" : "") + " space"); builder.add_rule("root", "\"<tool▁calls▁begin>\" (" + string_join(tool_rules, " | ") + ")" + (params.parallel_tool_calls ? "*" : "") + " space");
}, grammar_options); }, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
@ -499,6 +492,7 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__); fprintf(stderr, "[%s]\n", __func__);
common_chat_data data; common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array(); auto schemas = json::array();
foreach_function(params.tools, [&](const json & tool) { foreach_function(params.tools, [&](const json & tool) {
@ -525,9 +519,7 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
} }
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema)); builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
}, grammar_options); }, grammar_options);
if (params.tool_choice != "required") { data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
}
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "firefunction v2 tool calls"; data.format = "firefunction v2 tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) { data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
@ -542,6 +534,7 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
common_chat_data data; common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules; std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules; std::vector<std::string> subsequent_tool_rules;
@ -552,10 +545,8 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
auto args_rule = builder.add_schema(name + "-args", parameters); auto args_rule = builder.add_schema(name + "-args", parameters);
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
if (params.tool_choice != "required") { data.grammar_triggers.push_back({name, /* .at_start = */ true});
data.grammar_triggers.push_back({name, /* .at_start = */ true}); data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
}
}); });
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
if (params.parallel_tool_calls) { if (params.parallel_tool_calls) {
@ -591,6 +582,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
std::string python_code_argument_name; std::string python_code_argument_name;
auto has_raw_python = false; auto has_raw_python = false;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
foreach_function(params.tools, [&](const json & tool) { foreach_function(params.tools, [&](const json & tool) {
@ -624,15 +616,11 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
}); });
if (has_raw_python) { if (has_raw_python) {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
if (params.tool_choice != "required") { data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
}
} }
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space"; auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
if (params.tool_choice != "required") { data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
}
}, grammar_options); }, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
@ -666,6 +654,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
fprintf(stderr, "[%s]\n", __func__); fprintf(stderr, "[%s]\n", __func__);
common_chat_data data; common_chat_data data;
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)* // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
foreach_function(params.tools, [&](const json & tool) { foreach_function(params.tools, [&](const json & tool) {
@ -684,9 +673,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
}); });
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space"; auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
if (params.tool_choice != "required") { data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
}
}, grammar_options); }, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
@ -701,7 +688,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
std::sregex_iterator rend; std::sregex_iterator rend;
std::sregex_iterator rit(input.begin(), end, start_pattern); std::sregex_iterator rit(input.begin(), end, start_pattern);
if (rit == rend) { if (rit == rend) {
return {"assistant", input, {}}; return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
} }
common_chat_msg result; common_chat_msg result;
@ -732,7 +723,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
} }
return result; return result;
} catch (const std::exception & e) { } catch (const std::exception & e) {
return {"assistant", input, {}}; return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
} }
}); });
return data; return data;
@ -744,6 +739,7 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "content-only"; data.format = "content-only";
data.parser = std::make_unique<text_chat_parser>(); data.parser = std::make_unique<text_chat_parser>();
data.grammar_lazy = false;
if (!params.json_schema.is_null()) { if (!params.json_schema.is_null()) {
if (!params.grammar.empty()) { if (!params.grammar.empty()) {
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");

View file

@ -42,6 +42,7 @@ struct common_chat_data {
std::vector<std::string> additional_stops; std::vector<std::string> additional_stops;
std::unique_ptr<class common_chat_parser> parser; std::unique_ptr<class common_chat_parser> parser;
std::string format; // For debugging and testing. std::string format; // For debugging and testing.
bool grammar_lazy = false;
}; };
struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params); struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);

View file

@ -160,8 +160,9 @@ struct common_params_sampling {
}; };
std::string grammar; // optional BNF-like grammar to constrain sampling std::string grammar; // optional BNF-like grammar to constrain sampling
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to enable grammar bool grammar_lazy;
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to enable grammar std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
std::vector<llama_logit_bias> logit_bias; // logit biases to apply std::vector<llama_logit_bias> logit_bias; // logit biases to apply

View file

@ -159,6 +159,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
auto * result = new common_sampler { auto * result = new common_sampler {
/* .params = */ params, /* .params = */ params,
/* .grmr = */ llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root", /* .grmr = */ llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root",
params.grammar_lazy,
trigger_words.data(), trigger_words.size(), trigger_words.data(), trigger_words.size(),
params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()), params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()),
/* .chain = */ llama_sampler_chain_init(lparams), /* .chain = */ llama_sampler_chain_init(lparams),

View file

@ -76,7 +76,7 @@ int main(int argc, char** argv) {
grammar_str = buffer.str(); grammar_str = buffer.str();
} }
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0); llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0);
if (grammar == nullptr) { if (grammar == nullptr) {
fprintf(stdout, "Failed to initialize llama_grammar\n"); fprintf(stdout, "Failed to initialize llama_grammar\n");
return 1; return 1;

View file

@ -3816,6 +3816,7 @@ int main(int argc, char ** argv) {
task.params.oaicompat = oaicompat; task.params.oaicompat = oaicompat;
task.params.oaicompat_cmpl_id = completion_id; task.params.oaicompat_cmpl_id = completion_id;
task.params.sampling.grammar = chat_data.grammar; task.params.sampling.grammar = chat_data.grammar;
task.params.sampling.grammar_lazy = chat_data.grammar_lazy;
for (const auto & trigger : chat_data.grammar_triggers) { for (const auto & trigger : chat_data.grammar_triggers) {
auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true); auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) { if (ids.size() == 1) {
@ -3830,6 +3831,9 @@ int main(int argc, char ** argv) {
if (chat_data.parser) { if (chat_data.parser) {
task.params.chat_parser = i == tokenized_prompts.size() ? std::move(chat_data.parser) : std::move(chat_data.parser->clone()); task.params.chat_parser = i == tokenized_prompts.size() ? std::move(chat_data.parser) : std::move(chat_data.parser->clone());
} }
if (task.params.sampling.grammar_lazy) {
GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0);
}
// oaicompat_model is already populated by params_from_json_cmpl // oaicompat_model is already populated by params_from_json_cmpl
tasks.push_back(task); tasks.push_back(task);

View file

@ -1198,6 +1198,7 @@ extern "C" {
const struct llama_vocab * vocab, const struct llama_vocab * vocab,
const char * grammar_str, const char * grammar_str,
const char * grammar_root, const char * grammar_root,
bool lazy,
const char ** trigger_words, const char ** trigger_words,
size_t num_trigger_words, size_t num_trigger_words,
const llama_token * trigger_tokens, const llama_token * trigger_tokens,

View file

@ -964,7 +964,8 @@ struct llama_grammar * llama_grammar_init_impl(
vocab, vocab,
std::move(vec_rules), std::move(vec_rules),
std::move(stacks), std::move(stacks),
/* .partial_utf8 = */ {}, /* .partial_utf8 = */ {},
/* .lazy =*/ false,
/* .awaiting_trigger = */ false, /* .awaiting_trigger = */ false,
/* .trigger_buffer = */ "", /* .trigger_buffer = */ "",
/* .trigger_tokens = */ {}, /* .trigger_tokens = */ {},
@ -976,6 +977,7 @@ struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab, const struct llama_vocab * vocab,
const char * grammar_str, const char * grammar_str,
const char * grammar_root, const char * grammar_root,
bool lazy,
const char ** trigger_words, const char ** trigger_words,
size_t num_trigger_words, size_t num_trigger_words,
const llama_token * trigger_tokens, const llama_token * trigger_tokens,
@ -1069,8 +1071,9 @@ struct llama_grammar * llama_grammar_init_impl(
vocab, vocab,
std::move(vec_rules), std::move(vec_rules),
std::move(stacks), std::move(stacks),
/* .partial_utf8 = */ {}, /* .partial_utf8 = */ {},
/* .awaiting_trigger = */ vec_trigger_tokens.size() > 0 || vec_trigger_words.size() > 0, /* .lazy = */ lazy,
/* .awaiting_trigger = */ lazy,
/* .trigger_buffer = */ "", /* .trigger_buffer = */ "",
std::move(vec_trigger_tokens), std::move(vec_trigger_tokens),
std::move(vec_trigger_words), std::move(vec_trigger_words),
@ -1091,6 +1094,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
grammar.rules, grammar.rules,
grammar.stacks, grammar.stacks,
grammar.partial_utf8, grammar.partial_utf8,
grammar.lazy,
grammar.awaiting_trigger, grammar.awaiting_trigger,
grammar.trigger_buffer, grammar.trigger_buffer,
grammar.trigger_tokens, grammar.trigger_tokens,

View file

@ -116,9 +116,12 @@ struct llama_grammar {
llama_partial_utf8 partial_utf8; llama_partial_utf8 partial_utf8;
// lazy grammars wait for trigger words or tokens before constraining the sampling. // lazy grammars wait for trigger words or tokens before constraining the sampling.
bool awaiting_trigger; // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
std::string trigger_buffer; // (useful e.g. for tool_choice=required)
std::vector<llama_token> trigger_tokens; bool lazy; // Useful when resetting
bool awaiting_trigger; // Initialized to lazy
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
std::vector<std::string> trigger_words; std::vector<std::string> trigger_words;
}; };
@ -137,6 +140,7 @@ struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab, const struct llama_vocab * vocab,
const char * grammar_str, const char * grammar_str,
const char * grammar_root, const char * grammar_root,
bool lazy,
const char ** trigger_words, const char ** trigger_words,
size_t num_trigger_words, size_t num_trigger_words,
const llama_token * trigger_tokens, const llama_token * trigger_tokens,

View file

@ -1444,7 +1444,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
trigger_words.push_back(word.c_str()); trigger_words.push_back(word.c_str());
} }
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
trigger_words.data(), trigger_words.size(), ctx->grammar->lazy, trigger_words.data(), trigger_words.size(),
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
llama_grammar_free_impl(ctx->grammar); llama_grammar_free_impl(ctx->grammar);
@ -1454,7 +1454,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr, nullptr, 0, nullptr, 0); auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0);
// copy the state // copy the state
{ {
@ -1495,6 +1495,7 @@ struct llama_sampler * llama_sampler_init_grammar(
const struct llama_vocab * vocab, const struct llama_vocab * vocab,
const char * grammar_str, const char * grammar_str,
const char * grammar_root, const char * grammar_root,
bool lazy,
const char ** trigger_words, const char ** trigger_words,
size_t num_trigger_words, size_t num_trigger_words,
const llama_token * trigger_tokens, const llama_token * trigger_tokens,
@ -1506,7 +1507,7 @@ struct llama_sampler * llama_sampler_init_grammar(
/* .vocab = */ vocab, /* .vocab = */ vocab,
/* .grammar_str = */ grammar_str, /* .grammar_str = */ grammar_str,
/* .grammar_root = */ grammar_root, /* .grammar_root = */ grammar_root,
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens), /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
}; };
} else { } else {
*ctx = { *ctx = {

View file

@ -39,7 +39,7 @@ static std::string read_file(const std::string &path) {
} }
static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) { static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
return std::unique_ptr<llama_grammar>(llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0)); return std::unique_ptr<llama_grammar>(llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
} }
// TODO: extract to common helper (copied from test-grammar-integration.cpp) // TODO: extract to common helper (copied from test-grammar-integration.cpp)

View file

@ -13,7 +13,7 @@
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
static llama_grammar * build_grammar(const std::string & grammar_str) { static llama_grammar * build_grammar(const std::string & grammar_str) {
return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0); return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0);
} }
static bool test_build_grammar_fails(const std::string & grammar_str) { static bool test_build_grammar_fails(const std::string & grammar_str) {