From 130ca222c9ecdcf2c68cce39b4814ac5a8d4b7ba Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 3 Feb 2025 01:19:15 +0000 Subject: [PATCH] DeepSeek R1: parse thoughts / return in separate field in API (non streamed mode) --- common/chat.cpp | 41 +++++++++++++++++++++++++++++++++++--- common/common.h | 1 + examples/server/server.cpp | 3 +++ tests/test-chat.cpp | 11 +++++----- 4 files changed, 47 insertions(+), 9 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 63cc8ae17..51053eab9 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -22,6 +22,18 @@ std::string common_chat_format_name(common_chat_format format) { } } +static std::string string_trim(const std::string & s) { + size_t start = 0; + while (start < s.size() && std::isspace(s[start])) { + start++; + } + size_t end = s.size(); + while (end > start && std::isspace(s[end - 1])) { + end--; + } + return s.substr(start, end - start); +} + const common_grammar_options grammar_options { /* .dotall = */ false, /* .compact_spaces = */ false, @@ -537,20 +549,43 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ }); data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false}); data.preserved_tokens = { + "<think>", + "</think>", "<|tool▁sep|>", "<|tool▁call▁end|>", }; builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space"); }, grammar_options); - data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + /* + Note: we do not feed the thoughts back to the template for a few reasons: + - the template doesn't use them explicitly + - if content isn't null, tool calls aren't rendered + - not having the thoughts will locally reset the KV cache (losing the hot tokens of the tool calls) but will save up a lot long term. + */ + auto prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? 
json() : inputs.tools, inputs.add_generation_prompt); + std::string suffix = "<|Assistant|>"; + if (vocab && !llama_vocab_get_add_eos(vocab) && + inputs.add_generation_prompt && + !string_ends_with(prompt, suffix)) + { + prompt += "<|end▁of▁sentence|>"; + } + data.prompt = prompt; data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1; return data; } static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) { static std::regex trigger_regex("<|tool▁calls▁begin|>"); - static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n"); + static std::regex function_regex(R"(<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n)"); static std::regex close_regex("```<|tool▁call▁end|>"); - return parse_json_tool_calls(input, trigger_regex, function_regex, close_regex); + static std::regex think_regex(R"(<think>([\s\S\n]*)</think>([\s\S\r\n]*))"); + auto msg = parse_json_tool_calls(input, trigger_regex, function_regex, close_regex); + std::smatch match; + if (std::regex_match(msg.content, match, think_regex)) { + msg.thoughts = string_trim(match[1].str()); + msg.content = string_trim(match[2].str()); + } + return msg; } static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { diff --git a/common/common.h b/common/common.h index b208d0c7e..858d2807e 100644 --- a/common/common.h +++ b/common/common.h @@ -623,6 +623,7 @@ struct common_chat_msg { std::string role; std::string content; std::vector<common_chat_tool_call> tool_calls; + std::string thoughts = ""; std::string tool_plan = ""; }; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 4743f2a25..864184ba0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -744,6 +744,9 @@ struct server_task_result_cmpl_final : server_task_result { {"tool_calls", tool_calls}, {"role", "assistant"}, }; + if (!msg.thoughts.empty()) { + message["thoughts"] = msg.thoughts; + } if 
(!msg.tool_plan.empty()) { message["tool_plan"] = msg.tool_plan; } diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 9956c1f1f..a130d6c6c 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -18,18 +18,17 @@ using json = nlohmann::ordered_json; static common_chat_msg msg_from_json(const json & message) { - common_chat_msg ret{ - "assistant", - "", - {}, - /* .tool_plan = */ "", - }; + common_chat_msg ret; + ret.role = "assistant"; if (message.contains("content") && !message.at("content").is_null()) { ret.content = message.at("content"); } if (message.contains("tool_plan")) { ret.tool_plan = message.at("tool_plan"); } + if (message.contains("thoughts")) { + ret.thoughts = message.at("thoughts"); + } auto has_tool_calls = message.contains("tool_calls"); if (has_tool_calls) { for (const auto & tc : message.at("tool_calls")) {