tool-calls: allow r1 output to miss <think> opening tag (since latest template update adds it)

This commit is contained in:
ochafik 2025-02-09 15:50:21 +00:00
parent e598e7aa10
commit 91542ca245
2 changed files with 5 additions and 1 deletions

View file

@ -640,7 +640,7 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) { static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
static std::regex function_regex("<tool▁call▁begin>function<tool▁sep>([^\n]+)\n```json\n"); static std::regex function_regex("<tool▁call▁begin>function<tool▁sep>([^\n]+)\n```json\n");
static std::regex close_regex("```[\\s\\r\\n]*<tool▁call▁end>"); static std::regex close_regex("```[\\s\\r\\n]*<tool▁call▁end>");
static std::regex reasoning_content_regex("(<think>([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)"); static std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
static std::regex tool_calls_regex("[\\s\\r\\n]*(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>)([\\s\\S\\r\\n]*?)<tool▁calls▁end>"); static std::regex tool_calls_regex("[\\s\\r\\n]*(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>)([\\s\\S\\r\\n]*?)<tool▁calls▁end>");
common_chat_msg msg; common_chat_msg msg;
msg.role = "assistant"; msg.role = "assistant";

View file

@ -675,6 +675,10 @@ static void test_template_output_parsers() {
assert_msg_equals(msg_from_json(message_assist_thoughts), assert_msg_equals(msg_from_json(message_assist_thoughts),
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
assert_msg_equals(msg_from_json(message_assist_thoughts),
// Latest template update (ast of 20250209) adds a trailing <think>\n if add_generation_prompt is true.
common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
// test_template(tmpl, end_tokens, message_assist_call, tools, // test_template(tmpl, end_tokens, message_assist_call, tools,
// "<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function\n" // "<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function\n"
// "```json\n" // "```json\n"