From 18a11f43f08f191e80b086cfc7d1cc25a70a61e4 Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 4 Feb 2025 01:12:44 +0000 Subject: [PATCH] tool-call: r1: fix grammar --- common/chat.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 655cb9900..f4ac9fd2d 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -572,7 +572,7 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, // so we accept common variants (then it's all constrained) builder.add_rule("root", - "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\_calls\\_begin|>\" ) " + "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) " "(" +string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " " "\"<|tool▁calls▁end|>\"" " space"); @@ -580,6 +580,7 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ data.grammar_triggers.push_back({"<|tool_calls_begin|>", /* .at_start = */ false}); data.grammar_triggers.push_back({"<|tool calls begin|>", /* .at_start = */ false}); data.grammar_triggers.push_back({"<|tool\\_calls\\_begin|>", /* .at_start = */ false}); + data.grammar_triggers.push_back({"<|tool▁call▁begin|>", /* .at_start = */ false}); data.preserved_tokens = { "", "", @@ -613,9 +614,9 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ return data; } static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) { - static std::regex trigger_regex("<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>"); + static std::regex trigger_regex("(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)?"); static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n"); - static std::regex close_regex("```<|tool▁call▁end|>"); + static std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); static std::regex think_regex(R"(([\s\S\n]*)()?([\s\S\r\n]*))"); auto msg = parse_json_tool_calls(input, trigger_regex, function_regex, close_regex); std::smatch match;