diff --git a/common/chat.cpp b/common/chat.cpp
index 63cc8ae17..51053eab9 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -22,6 +22,18 @@ std::string common_chat_format_name(common_chat_format format) {
}
}
+// Returns a copy of `s` with leading and trailing whitespace (per std::isspace) removed.
+static std::string string_trim(const std::string & s) {
+    // std::isspace has undefined behavior for negative values, so go through unsigned char.
+    const auto is_space = [](unsigned char c) { return std::isspace(c) != 0; };
+    size_t start = 0;
+    while (start < s.size() && is_space(s[start])) { start++; }
+    size_t end = s.size();
+    // `end > start` also stops the scan when the whole string is whitespace.
+    while (end > start && is_space(s[end - 1])) { end--; }
+    return s.substr(start, end - start);
+}
+
const common_grammar_options grammar_options {
/* .dotall = */ false,
/* .compact_spaces = */ false,
@@ -537,20 +549,43 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
});
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
data.preserved_tokens = {
+ "",
+ "",
"<|tool▁sep|>",
"<|tool▁call▁end|>",
};
builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
}, grammar_options);
- data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+ /*
+ Note: we do not feed the thoughts back to the template for a few reasons:
+ - the template doesn't use them explicitly
+ - if content isn't null, tool calls aren't rendered
+ - not having the thoughts will locally reset the KV cache (losing the hot tokens of the tool calls) but will save a lot in the long term.
+ */
+ auto prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+ std::string suffix = "<|Assistant|>";
+ if (vocab && !llama_vocab_get_add_eos(vocab) &&
+ inputs.add_generation_prompt &&
+ !string_ends_with(prompt, suffix))
+ {
+ prompt += "<|end▁of▁sentence|>";
+ }
+ data.prompt = prompt;
data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
return data;
}
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) {
static std::regex trigger_regex("<|tool▁calls▁begin|>");
- static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
+ static std::regex function_regex(R"(<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n)"); // NOTE(review): an ASCII '|' is regex alternation; confirm the token bars are U+FF5C
static std::regex close_regex("```<|tool▁call▁end|>");
- return parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);
+ static std::regex think_regex(R"(<think>([\s\S\n]*)</think>([\s\S\r\n]*))"); // group 1: reasoning text, group 2: everything after </think>
+ auto msg = parse_json_tool_calls(input, trigger_regex, function_regex, close_regex); // extract tool calls first; the rest stays in msg.content
+ std::smatch match;
+ if (std::regex_match(msg.content, match, think_regex)) { // whole-string match: content must begin with <think>
+ msg.thoughts = string_trim(match[1].str());
+ msg.content = string_trim(match[2].str());
+ }
+ return msg;
}
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
diff --git a/common/common.h b/common/common.h
index b208d0c7e..858d2807e 100644
--- a/common/common.h
+++ b/common/common.h
@@ -623,6 +623,7 @@ struct common_chat_msg {
std::string role;
std::string content;
std::vector tool_calls;
+ std::string thoughts = "";
std::string tool_plan = "";
};
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 4743f2a25..864184ba0 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -744,6 +744,9 @@ struct server_task_result_cmpl_final : server_task_result {
{"tool_calls", tool_calls},
{"role", "assistant"},
};
+ if (!msg.thoughts.empty()) {
+ message["thoughts"] = msg.thoughts;
+ }
if (!msg.tool_plan.empty()) {
message["tool_plan"] = msg.tool_plan;
}
diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp
index 9956c1f1f..a130d6c6c 100644
--- a/tests/test-chat.cpp
+++ b/tests/test-chat.cpp
@@ -18,18 +18,17 @@
using json = nlohmann::ordered_json;
static common_chat_msg msg_from_json(const json & message) {
- common_chat_msg ret{
- "assistant",
- "",
- {},
- /* .tool_plan = */ "",
- };
+ common_chat_msg ret;
+ ret.role = "assistant";
if (message.contains("content") && !message.at("content").is_null()) {
ret.content = message.at("content");
}
if (message.contains("tool_plan")) {
ret.tool_plan = message.at("tool_plan");
}
+ if (message.contains("thoughts")) {
+ ret.thoughts = message.at("thoughts");
+ }
auto has_tool_calls = message.contains("tool_calls");
if (has_tool_calls) {
for (const auto & tc : message.at("tool_calls")) {