From 1f5ec598091f84cb99e0d40f392864ee1bc7fbd2 Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 4 Feb 2025 04:48:08 +0000 Subject: [PATCH] ensure deepseek r1 thoughts parsed even w/o tool calls --- common/chat.cpp | 76 ++++++++++---------- examples/server/tests/unit/test_tool_call.py | 64 +++++++++++++---- 2 files changed, 91 insertions(+), 49 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index c134ae568..8ce430abc 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -565,39 +565,41 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { common_chat_params data; - data.grammar_lazy = inputs.tool_choice != "required"; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector<std::string> tool_rules; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; - auto args_rule = builder.add_schema(name + "-args", parameters); - tool_rules.push_back(builder.add_rule(name + "-call", - "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n" - "```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\"")); - }); - // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, - // so we accept common variants (then it's all constrained) - builder.add_rule("root", - "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) " - "(" +string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? 
"*" : "") + " " - "\"<|tool▁calls▁end|>\"" - " space"); - data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false}); - data.grammar_triggers.push_back({"<|tool_calls_begin|>", /* .at_start = */ false}); - data.grammar_triggers.push_back({"<|tool calls begin|>", /* .at_start = */ false}); - data.grammar_triggers.push_back({"<|tool\\_calls\\_begin|>", /* .at_start = */ false}); - data.grammar_triggers.push_back({"<|tool▁call▁begin|>", /* .at_start = */ false}); - data.preserved_tokens = { - "", - "", - "<|tool▁sep|>", - "<|tool▁calls▁end|", - "<|tool▁call▁begin|>", - "<|tool▁call▁end|>", - }; - }, grammar_options); + if (!inputs.tools.is_null() && !inputs.tools.empty()) { + data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector<std::string> tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + auto args_rule = builder.add_schema(name + "-args", parameters); + tool_rules.push_back(builder.add_rule(name + "-call", + "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n" + "```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\"")); + }); + // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, + // so we accept common variants (then it's all constrained) + builder.add_rule("root", + "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) " + "(" +string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? 
"*" : "") + " " + "\"<|tool▁calls▁end|>\"" + " space"); + data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false}); + data.grammar_triggers.push_back({"<|tool_calls_begin|>", /* .at_start = */ false}); + data.grammar_triggers.push_back({"<|tool calls begin|>", /* .at_start = */ false}); + data.grammar_triggers.push_back({"<|tool\\_calls\\_begin|>", /* .at_start = */ false}); + data.grammar_triggers.push_back({"<|tool▁call▁begin|>", /* .at_start = */ false}); + data.preserved_tokens = { + "", + "", + "<|tool▁sep|>", + "<|tool▁calls▁end|", + "<|tool▁call▁begin|>", + "<|tool▁call▁end|>", + }; + }, grammar_options); + } auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); // Hacks to fix the official (broken) prompt. @@ -638,7 +640,7 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex); msg.tool_calls = std::move(msg2.tool_calls); } else { - msg.content = rest; + msg.content = rest.find_first_not_of(" \r\n") == std::string::npos ? std::string() : rest.substr(rest.find_first_not_of(" \r\n")); /* find_first_not_of() returns npos when the text is empty or all whitespace; yield "" rather than forming an out-of-range iterator */ } } else { msg.content = input; @@ -970,6 +972,9 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co // Firefunction v2 requires datetime and functions in the context, even w/o tools. 
return common_chat_params_init_firefunction_v2(tmpl, inputs); } + if (src.find("<|tool▁calls▁begin|>") != std::string::npos) { + return common_chat_params_init_deepseek_r1(tmpl, inputs); + } if (!has_tools) { return common_chat_params_init_without_tools(tmpl, inputs); @@ -986,9 +991,6 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos; return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools); } - if (src.find("<|tool▁calls▁begin|>") != std::string::npos) { - return common_chat_params_init_deepseek_r1(tmpl, inputs); - } if (src.find("[TOOL_CALLS]") != std::string::npos) { return common_chat_params_init_mistral_nemo(tmpl, inputs); } diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py index 5ae9fa261..70288dbf3 100644 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -345,20 +345,20 @@ def test_weather_tool_call(hf_repo: str, template_override: str | Tuple[str, str @pytest.mark.slow @pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [ - (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), - (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), - (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), - (None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - (None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), - (None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), - (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), - (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), - ("^> 0.56$", 128, 
"bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + (None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), + (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + ("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value) - ("[\\s\\S\\r\\n]*?\\b0\\.55644242476$", 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), - ("[\\s\\S\\r\\n]*?which equals 0\\.5\\.", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - ("**Answer:** 0\\.25\\b", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), + ("[\\s\\S\\r\\n]*?\\b0\\.55644242476$", 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + ("[\\s\\S\\r\\n]*?which equals 0\\.5\\.", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + ("[\\s\\S\\r\\n]*?\\*\\*Answer:\\*\\* 0\\.25\\b", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), ]) def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None): global server @@ -435,6 +435,46 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, f'Expected something like "The y coordinate is 0.56.", got {content}' +@pytest.mark.slow 
+@pytest.mark.parametrize("n_predict,expect_content,expect_thoughts,hf_repo,template_override", [ + (128, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (1024, "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + (1024, "To find the sum of.*", "First, I need to add the tens place.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), +]) +def test_thoughts(n_predict: int, expect_content: str | None, expect_thoughts: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): + global server + server.n_slots = 1 + server.jinja = True + server.n_ctx = 8192 * 2 + server.n_predict = n_predict + server.model_hf_repo = hf_repo + server.model_hf_file = None + if isinstance(template_override, tuple): + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." 
+ elif isinstance(template_override, str): + server.chat_template = template_override + server.start(timeout_seconds=TIMEOUT_SERVER_START) + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "user", "content": "What's the sum of 102 and 7?"}, + ] + }, timeout=TIMEOUT_HTTP_REQUEST) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}' + + content = choice["message"].get("content") + if expect_content is not None: + assert content is not None and re.match(expect_content, content), f'Expected {expect_content}, got {content}' + + thoughts = choice["message"].get("thoughts") + if expect_thoughts is not None: + assert thoughts is not None and re.match(expect_thoughts, thoughts), f'Expected {expect_thoughts}, got {thoughts}' + + @pytest.mark.slow @pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [ (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),