Fix and test parsing in the DeepSeek R1 chat parser

This commit is contained in:
ochafik 2025-02-04 02:23:31 +00:00
parent 9a6847c857
commit a682d1216d
2 changed files with 33 additions and 19 deletions

View file

@@ -606,8 +606,8 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
// Fix up tool call delta example added by Minja
prompt = std::regex_replace(
prompt,
std::regex("<tool▁call▁end>[\\s\\r\\n]*<User>"),
"<tool▁call▁end><tool▁calls▁end><User>");
std::regex("(<tool▁call▁end>)[\\s\\r\\n]*(<tool▁outputs▁begin>|<User>)"),
"$1<tool▁calls▁end><end▁of▁sentence>$2");
}
data.prompt = prompt;
data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
@@ -617,7 +617,7 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input)
static std::regex trigger_regex("(<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>)?");
static std::regex function_regex("<tool▁call▁begin>function<tool▁sep>([^\n]+)\n```json\n");
static std::regex close_regex("```[\\s\\r\\n]*<tool▁call▁end>");
static std::regex think_regex(R"(<think>([\s\S\n]*)(</think>)?([\s\S\r\n]*))");
static std::regex think_regex("<think>([\\s\\S\\n]*?)</think>([\\s\\S\\r\\n]*)");
auto msg = parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);
std::smatch match;
if (std::regex_match(msg.content, match, think_regex)) {