ensure deepseek r1 thoughts parsed even w/o tool calls

This commit is contained in:
ochafik 2025-02-04 04:48:08 +00:00
parent b6e14a4101
commit 1f5ec59809
2 changed files with 91 additions and 49 deletions

View file

@ -565,39 +565,41 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
auto args_rule = builder.add_schema(name + "-args", parameters);
tool_rules.push_back(builder.add_rule(name + "-call",
"\"<tool▁call▁begin>function<tool▁sep>" + name + "\\n"
"```json\\n\" " + args_rule + " \"```<tool▁call▁end>\""));
});
// Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
// so we accept common variants (then it's all constrained)
builder.add_rule("root",
"( \"<tool▁calls▁begin>\" | \"<tool_calls_begin>\" | \"<tool calls begin>\" | \"<tool\\\\_calls\\\\_begin>\" ) "
"(" +string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
"\"<tool▁calls▁end>\""
" space");
data.grammar_triggers.push_back({"<tool▁calls▁begin>", /* .at_start = */ false});
data.grammar_triggers.push_back({"<tool_calls_begin>", /* .at_start = */ false});
data.grammar_triggers.push_back({"<tool calls begin>", /* .at_start = */ false});
data.grammar_triggers.push_back({"<tool\\_calls\\_begin>", /* .at_start = */ false});
data.grammar_triggers.push_back({"<tool▁call▁begin>", /* .at_start = */ false});
data.preserved_tokens = {
"<think>",
"</think>",
"<tool▁sep>",
"<tool▁calls▁end",
"<tool▁call▁begin>",
"<tool▁call▁end>",
};
}, grammar_options);
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
auto args_rule = builder.add_schema(name + "-args", parameters);
tool_rules.push_back(builder.add_rule(name + "-call",
"\"<tool▁call▁begin>function<tool▁sep>" + name + "\\n"
"```json\\n\" " + args_rule + " \"```<tool▁call▁end>\""));
});
// Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
// so we accept common variants (then it's all constrained)
builder.add_rule("root",
"( \"<tool▁calls▁begin>\" | \"<tool_calls_begin>\" | \"<tool calls begin>\" | \"<tool\\\\_calls\\\\_begin>\" ) "
"(" +string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
"\"<tool▁calls▁end>\""
" space");
data.grammar_triggers.push_back({"<tool▁calls▁begin>", /* .at_start = */ false});
data.grammar_triggers.push_back({"<tool_calls_begin>", /* .at_start = */ false});
data.grammar_triggers.push_back({"<tool calls begin>", /* .at_start = */ false});
data.grammar_triggers.push_back({"<tool\\_calls\\_begin>", /* .at_start = */ false});
data.grammar_triggers.push_back({"<tool▁call▁begin>", /* .at_start = */ false});
data.preserved_tokens = {
"<think>",
"</think>",
"<tool▁sep>",
"<tool▁calls▁end",
"<tool▁call▁begin>",
"<tool▁call▁end>",
};
}, grammar_options);
}
auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
// Hacks to fix the official (broken) prompt.
@ -638,7 +640,7 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input)
auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
msg.tool_calls = std::move(msg2.tool_calls);
} else {
msg.content = rest;
msg.content = std::string(rest.begin() + rest.find_first_not_of(" \r\n"), rest.end());
}
} else {
msg.content = input;
@ -970,6 +972,9 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
// Firefunction v2 requires datetime and functions in the context, even w/o tools.
return common_chat_params_init_firefunction_v2(tmpl, inputs);
}
if (src.find("<tool▁calls▁begin>") != std::string::npos) {
return common_chat_params_init_deepseek_r1(tmpl, inputs);
}
if (!has_tools) {
return common_chat_params_init_without_tools(tmpl, inputs);
@ -986,9 +991,6 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
}
if (src.find("<tool▁calls▁begin>") != std::string::npos) {
return common_chat_params_init_deepseek_r1(tmpl, inputs);
}
if (src.find("[TOOL_CALLS]") != std::string::npos) {
return common_chat_params_init_mistral_nemo(tmpl, inputs);
}