From 90effb845f2c86353af463841c454c37619715fa Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 27 Jan 2025 22:46:17 +0000 Subject: [PATCH] Pass grammar laziness all the way down to sampler (need to print special trigger tokens e.g. for Nemo even w/ tool_choice=required) --- common/chat-handler.cpp | 66 ++++++++++------------ common/chat-handler.hpp | 1 + common/common.h | 5 +- common/sampling.cpp | 1 + examples/gbnf-validator/gbnf-validator.cpp | 2 +- examples/server/server.cpp | 4 ++ include/llama.h | 1 + src/llama-grammar.cpp | 10 +++- src/llama-grammar.h | 10 +++- src/llama-sampling.cpp | 7 ++- tests/test-chat-handler.cpp | 2 +- tests/test-grammar-integration.cpp | 2 +- 12 files changed, 62 insertions(+), 49 deletions(-) diff --git a/common/chat-handler.cpp b/common/chat-handler.cpp index 8ea031bd5..19b11d689 100644 --- a/common/chat-handler.cpp +++ b/common/chat-handler.cpp @@ -279,6 +279,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem } : tool_call; + data.grammar_lazy = false; data.grammar = build_grammar([&](const common_grammar_builder & builder) { builder.add_schema("root", schema); }, grammar_options); @@ -319,6 +320,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { fprintf(stderr, "[%s]\n", __func__); common_chat_data data; + data.grammar_lazy = params.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(params.tools, [&](const json & tool) { @@ -352,9 +354,7 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha } builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); }, grammar_options); - if (params.tool_choice != "required") { - data.grammar_triggers.push_back({"[TOOL_CALLS]", /* 
.at_start = */ true}); - } + data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true}); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); data.format = "mistral nemo tool calls"; data.parser = std::make_unique([](const std::string & input) { @@ -369,6 +369,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c auto builtin_tools = json {"wolfram_alpha", "brave_search"}; common_chat_data data; + data.grammar_lazy = params.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; @@ -385,14 +386,10 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + builder.add_schema(name + "-args", parameters) + " \"}\"")); - if (params.tool_choice != "required") { - data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true}); - } + data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true}); }); tool_rules.push_back(builder.add_rule("builtin-tool-call", "\"<|python_tag|>\" .*")); - if (params.tool_choice != "required") { - data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); - } + data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); builder.add_rule("root", string_join(tool_rules, " | ")); }, grammar_options); data.additional_stops.push_back("<|eom_id|>"); @@ -429,6 +426,7 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_ fprintf(stderr, "[%s]\n", __func__); common_chat_data data; + data.grammar_lazy = params.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; @@ -446,9 +444,7 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_ "\"\\\"name\\\": \\\"" + name + 
"\\\", \\\"parameters\\\": \" " + builder.add_schema(name + "-args", parameters) + " \"}\"")); - if (params.tool_choice != "required") { - data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true}); - } + data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true}); }); builder.add_rule("root", string_join(tool_rules, " | ")); @@ -468,8 +464,7 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { fprintf(stderr, "[%s]\n", __func__); common_chat_data data; - data.grammar = "root ::= .*"; - // data.grammar = "root ::= .*"; + data.grammar_lazy = params.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; foreach_function(params.tools, [&](const json & tool) { @@ -480,9 +475,7 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat tool_rules.push_back(builder.add_rule(name + "-call", "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\"")); }); - if (params.tool_choice != "required") { - data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false}); - } + data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false}); builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (params.parallel_tool_calls ? "*" : "") + " space"); }, grammar_options); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? 
json() : params.tools, params.add_generation_prompt); @@ -499,6 +492,7 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { fprintf(stderr, "[%s]\n", __func__); common_chat_data data; + data.grammar_lazy = params.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(params.tools, [&](const json & tool) { @@ -525,9 +519,7 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_ } builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema)); }, grammar_options); - if (params.tool_choice != "required") { - data.grammar_triggers.push_back({" functools[", /* .at_start = */ false}); - } + data.grammar_triggers.push_back({" functools[", /* .at_start = */ false}); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); data.format = "firefunction v2 tool calls"; data.parser = std::make_unique([](const std::string & input) { @@ -542,6 +534,7 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common // Using ">>>f1\n", ">>>f2\n"... 
as trigger words for the grammar common_chat_data data; + data.grammar_lazy = params.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector first_tool_rules; std::vector subsequent_tool_rules; @@ -552,10 +545,8 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common auto args_rule = builder.add_schema(name + "-args", parameters); first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); - if (params.tool_choice != "required") { - data.grammar_triggers.push_back({name, /* .at_start = */ true}); - data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false}); - } + data.grammar_triggers.push_back({name, /* .at_start = */ true}); + data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false}); }); auto first_rule = first_tool_rules.empty() ? 
"" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; if (params.parallel_tool_calls) { @@ -591,6 +582,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons std::string python_code_argument_name; auto has_raw_python = false; + data.grammar_lazy = params.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; foreach_function(params.tools, [&](const json & tool) { @@ -624,15 +616,11 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons }); if (has_raw_python) { tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); - if (params.tool_choice != "required") { - data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); - } + data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); } auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space"; builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); - if (params.tool_choice != "required") { - data.grammar_triggers.push_back({"{"name": "foo", "arguments": {"a": 1}})* + data.grammar_lazy = params.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; foreach_function(params.tools, [&](const json & tool) { @@ -684,9 +673,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha }); auto tool_call = "\"\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"\" space"; builder.add_rule("root", params.parallel_tool_calls ? 
"(" + tool_call + ")+" : tool_call); - if (params.tool_choice != "required") { - data.grammar_triggers.push_back({"", /* .at_start = */ false}); - } + data.grammar_triggers.push_back({"", /* .at_start = */ false}); }, grammar_options); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); @@ -701,7 +688,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha std::sregex_iterator rend; std::sregex_iterator rit(input.begin(), end, start_pattern); if (rit == rend) { - return {"assistant", input, {}}; + return { + /* .role = */ "assistant", + /* .content = */ input, + /* .tool_calls = */ {}, + }; } common_chat_msg result; @@ -732,7 +723,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha } return result; } catch (const std::exception & e) { - return {"assistant", input, {}}; + return { + /* .role = */ "assistant", + /* .content = */ input, + /* .tool_calls = */ {}, + }; } }); return data; @@ -744,6 +739,7 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); data.format = "content-only"; data.parser = std::make_unique(); + data.grammar_lazy = false; if (!params.json_schema.is_null()) { if (!params.grammar.empty()) { throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); diff --git a/common/chat-handler.hpp b/common/chat-handler.hpp index 8100d1dc6..2ba85893c 100644 --- a/common/chat-handler.hpp +++ b/common/chat-handler.hpp @@ -42,6 +42,7 @@ struct common_chat_data { std::vector additional_stops; std::unique_ptr parser; std::string format; // For debugging and testing. 
+ bool grammar_lazy = false; }; struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params); diff --git a/common/common.h b/common/common.h index e075d39dd..c32d4d067 100644 --- a/common/common.h +++ b/common/common.h @@ -160,8 +160,9 @@ struct common_params_sampling { }; std::string grammar; // optional BNF-like grammar to constrain sampling - std::vector grammar_trigger_words; // optional trigger words to enable grammar - std::vector grammar_trigger_tokens; // optional trigger tokens to enable grammar + bool grammar_lazy; + std::vector grammar_trigger_words; // optional trigger words to trigger lazy grammar + std::vector grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens. std::vector logit_bias; // logit biases to apply diff --git a/common/sampling.cpp b/common/sampling.cpp index 08ecb4599..852904552 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -159,6 +159,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co auto * result = new common_sampler { /* .params = */ params, /* .grmr = */ llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root", + params.grammar_lazy, trigger_words.data(), trigger_words.size(), params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()), /* .chain = */ llama_sampler_chain_init(lparams), diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 83cc71817..a610e6a0b 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -76,7 +76,7 @@ int main(int argc, char** argv) { grammar_str = buffer.str(); } - llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0); + llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0); if (grammar == nullptr) 
{ fprintf(stdout, "Failed to initialize llama_grammar\n"); return 1; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index a96552dff..43705a21d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3816,6 +3816,7 @@ int main(int argc, char ** argv) { task.params.oaicompat = oaicompat; task.params.oaicompat_cmpl_id = completion_id; task.params.sampling.grammar = chat_data.grammar; + task.params.sampling.grammar_lazy = chat_data.grammar_lazy; for (const auto & trigger : chat_data.grammar_triggers) { auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true); if (ids.size() == 1) { @@ -3830,6 +3831,9 @@ int main(int argc, char ** argv) { if (chat_data.parser) { task.params.chat_parser = i == tokenized_prompts.size() ? std::move(chat_data.parser) : std::move(chat_data.parser->clone()); } + if (task.params.sampling.grammar_lazy) { + GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0); + } // oaicompat_model is already populated by params_from_json_cmpl tasks.push_back(task); diff --git a/include/llama.h b/include/llama.h index d2f00d23b..fc37974d3 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1198,6 +1198,7 @@ extern "C" { const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root, + bool lazy, const char ** trigger_words, size_t num_trigger_words, const llama_token * trigger_tokens, diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 3f2ef1165..589324a85 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -964,7 +964,8 @@ struct llama_grammar * llama_grammar_init_impl( vocab, std::move(vec_rules), std::move(stacks), - /* .partial_utf8 = */ {}, + /* .partial_utf8 = */ {}, + /* .lazy =*/ false, /* .awaiting_trigger = */ false, /* .trigger_buffer = */ "", /* .trigger_tokens = */ {}, @@ -976,6 +977,7 @@ struct llama_grammar * 
llama_grammar_init_impl( const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root, + bool lazy, const char ** trigger_words, size_t num_trigger_words, const llama_token * trigger_tokens, @@ -1069,8 +1071,9 @@ struct llama_grammar * llama_grammar_init_impl( vocab, std::move(vec_rules), std::move(stacks), - /* .partial_utf8 = */ {}, - /* .awaiting_trigger = */ vec_trigger_tokens.size() > 0 || vec_trigger_words.size() > 0, + /* .partial_utf8 = */ {}, + /* .lazy = */ lazy, + /* .awaiting_trigger = */ lazy, /* .trigger_buffer = */ "", std::move(vec_trigger_tokens), std::move(vec_trigger_words), @@ -1091,6 +1094,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra grammar.rules, grammar.stacks, grammar.partial_utf8, + grammar.lazy, grammar.awaiting_trigger, grammar.trigger_buffer, grammar.trigger_tokens, diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 38e7aff96..dfd0f4764 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -116,9 +116,12 @@ struct llama_grammar { llama_partial_utf8 partial_utf8; // lazy grammars wait for trigger words or tokens before constraining the sampling. - bool awaiting_trigger; - std::string trigger_buffer; - std::vector trigger_tokens; + // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens. + // (useful e.g. for tool_choice=required) + bool lazy; // Useful when resetting + bool awaiting_trigger; // Initialized to lazy + std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. + std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). 
std::vector trigger_words; }; @@ -137,6 +140,7 @@ struct llama_grammar * llama_grammar_init_impl( const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root, + bool lazy, const char ** trigger_words, size_t num_trigger_words, const llama_token * trigger_tokens, diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 82b2b474c..f9fd7441d 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1444,7 +1444,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { trigger_words.push_back(word.c_str()); } auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), - trigger_words.data(), trigger_words.size(), + ctx->grammar->lazy, trigger_words.data(), trigger_words.size(), ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); llama_grammar_free_impl(ctx->grammar); @@ -1454,7 +1454,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; - auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr, nullptr, 0, nullptr, 0); + auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0); // copy the state { @@ -1495,6 +1495,7 @@ struct llama_sampler * llama_sampler_init_grammar( const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root, + bool lazy, const char ** trigger_words, size_t num_trigger_words, const llama_token * trigger_tokens, @@ -1506,7 +1507,7 @@ struct llama_sampler * llama_sampler_init_grammar( /* .vocab = */ vocab, /* .grammar_str = */ grammar_str, /* .grammar_root = */ grammar_root, - /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens), + /* .grammar = */ 
llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens), }; } else { *ctx = { diff --git a/tests/test-chat-handler.cpp b/tests/test-chat-handler.cpp index 14c441fe9..f28784ccb 100644 --- a/tests/test-chat-handler.cpp +++ b/tests/test-chat-handler.cpp @@ -39,7 +39,7 @@ static std::string read_file(const std::string &path) { } static std::unique_ptr build_grammar(const std::string & grammar_str) { - return std::unique_ptr(llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0)); + return std::unique_ptr(llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0)); } // TODO: extract to common helper (copied from test-grammar-integration.cpp) diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 60169dfd6..288e08f51 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -13,7 +13,7 @@ using json = nlohmann::ordered_json; static llama_grammar * build_grammar(const std::string & grammar_str) { - return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0); + return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0); } static bool test_build_grammar_fails(const std::string & grammar_str) {