From a682d1216df684691f05ae4634a9f4f3f6e16d55 Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 4 Feb 2025 02:23:31 +0000 Subject: [PATCH] fix / test parsing of r1 parser --- common/chat.cpp | 6 +++--- tests/test-chat.cpp | 46 +++++++++++++++++++++++++++++---------------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index f4ac9fd2d..8d4331cb1 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -606,8 +606,8 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ // Fix up tool call delta example added by Minja prompt = std::regex_replace( prompt, - std::regex("<|tool▁call▁end|>[\\s\\r\\n]*<|User|>"), - "<|tool▁call▁end|><|tool▁calls▁end|><|User|>"); + std::regex("(<|tool▁call▁end|>)[\\s\\r\\n]*(<|tool▁outputs▁begin|>|<|User|>)"), + "$1<|tool▁calls▁end|><|end▁of▁sentence|>$2"); } data.prompt = prompt; data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1; @@ -617,7 +617,7 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) static std::regex trigger_regex("(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)?"); static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n"); static std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); - static std::regex think_regex(R"(([\s\S\n]*)()?([\s\S\r\n]*))"); + static std::regex think_regex("([\\s\\S\\n]*?)([\\s\\S\\r\\n]*)"); auto msg = parse_json_tool_calls(input, trigger_regex, function_regex, close_regex); std::smatch match; if (std::regex_match(msg.content, match, think_regex)) { diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index a130d6c6c..b0eee0a0a 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -108,6 +108,8 @@ static std::string dump(const json & j) { static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) { assert_equals(expected.role, actual.role); assert_equals(expected.content, actual.content); + assert_equals(expected.thoughts, actual.thoughts); + assert_equals(expected.tool_plan, actual.tool_plan); assert_equals(expected.tool_calls.size(), actual.tool_calls.size()); for (size_t i = 0; i < expected.tool_calls.size(); i++) { const auto & expected_tool_call = expected.tool_calls[i]; @@ -226,7 +228,8 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto */ static void test_template(const common_chat_template & tmpl, const std::vector & end_tokens, const json & test_message, const json & tools = {}, const std::string & expected_delta = "", - bool expect_grammar_triggered = true) { + bool expect_grammar_triggered = true, + bool test_grammar_if_triggered = true) { common_chat_msg expected_msg = msg_from_json(test_message); auto user_message = json{ @@ -277,7 +280,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector and others unclosed. Our logic fixes the prompt. + const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), + "", ""); + std::vector end_tokens{ "<|end▁of▁sentence|>" }; + + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format); + + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, text_thoughts_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false); + assert_msg_equals(msg_from_json(text_thoughts_message), common_chat_parse("I'm thinkingHello, world!", COMMON_CHAT_FORMAT_DEEPSEEK_R1)); + test_template(tmpl, end_tokens, tool_call_message, tools, + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" + "```json\n" + "{\"arg1\": 1}\n" + // Look what's not here: <|tool▁calls▁end|> (also missing the <|end▁of▁sentence|>, but that is removed lazily by the test's delta logic) + "```<|tool▁call▁end|>", + /* expect_grammar_triggered= */ true, + /* test_grammar_if_triggered= */ false); + } { // Not supported yet const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "", ""); @@ -558,20 +586,6 @@ static void test_template_output_parsers() { test_template(tmpl, end_tokens, tool_call_message, tools, " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]"); } - { - const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), - "", ""); - std::vector end_tokens{ "<|end▁of▁sentence|>" }; - - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format); - - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, tool_call_message, tools, - "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" - "```json\n" - "{\"arg1\": 1}\n" - "```<|tool▁call▁end|>"); - } } int main(int argc, char ** argv) {