diff --git a/common/chat-handler.cpp b/common/chat-handler.cpp index 43f2d9e4e..feb2245d7 100644 --- a/common/chat-handler.cpp +++ b/common/chat-handler.cpp @@ -443,10 +443,26 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat fprintf(stderr, "[%s]\n", __func__); common_chat_data data; data.grammar = "root ::= .*"; + // data.grammar = "root ::= .*"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(params.tools, [&](const json & tool) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + auto args_rule = builder.add_schema(name + "-args", parameters); + tool_rules.push_back(builder.add_rule(name + "-call", + "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\"")); + }); + if (params.tool_choice != "required") { + data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false}); + } + builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (params.parallel_tool_calls ? "*" : "") + " space"); + }, grammar_options); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); data.parser = std::make_unique([params](const std::string & input) -> common_chat_msg { static std::regex trigger_regex("<|tool▁calls▁begin|>"); - static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^<]+)\n```json\n"); + static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n"); static std::regex close_regex("```<|tool▁call▁end|>"); return parse_json_tool_calls(params.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true); }); diff --git a/tests/test-chat-handler.cpp b/tests/test-chat-handler.cpp index 2d3f986be..88366c685 100644 --- a/tests/test-chat-handler.cpp +++ b/tests/test-chat-handler.cpp @@ -353,6 +353,10 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c throw std::runtime_error("Full message does not start with prefix"); } + if (full == prefix) { + throw std::runtime_error("Full message is the same as the prefix"); + } + auto delta = full.substr(prefix.size()); // Strip end tokens @@ -398,7 +402,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector", ""); - test_template(tmpl, {}, tool_call_message, tools); + test_template(tmpl, { "<|end▁of▁sentence|>" }, tool_call_message, tools); } }