Fix / test models/templates/llama-cpp-deepseek-r1.jinja

This commit is contained in:
ochafik 2025-02-04 03:09:15 +00:00
parent a682d1216d
commit f0154a6479
3 changed files with 95 additions and 47 deletions

View file

@ -614,18 +614,26 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
return data;
}
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) {
static std::regex trigger_regex("(<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>)?");
static std::regex function_regex("<tool▁call▁begin>function<tool▁sep>([^\n]+)\n```json\n");
static std::regex close_regex("```[\\s\\r\\n]*<tool▁call▁end>");
static std::regex think_regex("<think>([\\s\\S\\n]*?)</think>([\\s\\S\\r\\n]*)");
auto msg = parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);
static std::regex thoughts_regex("(?:<think>([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
static std::regex tool_calls_regex("[\\s\\r\\n]*(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>)([\\s\\S\\r\\n]*?)<tool▁calls▁end>");
common_chat_msg msg;
msg.role = "assistant";
std::smatch match;
if (std::regex_match(msg.content, match, think_regex)) {
if (std::regex_match(input, match, thoughts_regex)) {
msg.thoughts = string_trim(match[1].str());
msg.content = string_trim(match[2].str());
}
if (string_trim(msg.content) == "<tool▁calls▁end>") {
msg.content = "";
auto rest = match[2].str();
if (std::regex_search(rest, match, tool_calls_regex)) {
auto tool_calls = match[1].str();
auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
msg.tool_calls = std::move(msg2.tool_calls);
} else {
msg.content = rest;
}
} else {
msg.content = input;
}
return msg;
}