DeepSeek R1: parse thoughts / return in separate field in API (non streamed mode)
This commit is contained in:
parent
87de852b7f
commit
130ca222c9
4 changed files with 47 additions and 9 deletions
|
@ -22,6 +22,18 @@ std::string common_chat_format_name(common_chat_format format) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns a copy of `s` with leading and trailing whitespace removed.
//
// Whitespace is whatever std::isspace reports for the current C locale.
// Every char is converted to unsigned char before classification: handing
// std::isspace a negative value (any non-ASCII byte on platforms where char
// is signed, e.g. part of a UTF-8 sequence) is undefined behavior.
//
// An empty or all-whitespace input yields an empty string.
static std::string string_trim(const std::string & s) {
    const auto is_space = [](char c) {
        return std::isspace(static_cast<unsigned char>(c)) != 0;
    };

    // First index that is not whitespace (s.size() if there is none).
    size_t start = 0;
    while (start < s.size() && is_space(s[start])) {
        start++;
    }

    // One past the last non-whitespace index; never moves below `start`.
    size_t end = s.size();
    while (end > start && is_space(s[end - 1])) {
        end--;
    }

    return s.substr(start, end - start);
}
|
||||||
|
|
||||||
const common_grammar_options grammar_options {
|
const common_grammar_options grammar_options {
|
||||||
/* .dotall = */ false,
|
/* .dotall = */ false,
|
||||||
/* .compact_spaces = */ false,
|
/* .compact_spaces = */ false,
|
||||||
|
@ -537,20 +549,43 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
|
||||||
});
|
});
|
||||||
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
|
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
|
||||||
data.preserved_tokens = {
|
data.preserved_tokens = {
|
||||||
|
"<think>",
|
||||||
|
"</think>",
|
||||||
"<|tool▁sep|>",
|
"<|tool▁sep|>",
|
||||||
"<|tool▁call▁end|>",
|
"<|tool▁call▁end|>",
|
||||||
};
|
};
|
||||||
builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
|
builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
/*
|
||||||
|
Note: we do not feed the thoughts back to the template for a few reasons:
|
||||||
|
- the template doesn't use them explicitly
|
||||||
|
- if content isn't null, tool calls aren't rendered
|
||||||
|
- not having the thoughts will locally reset the KV cache (losing the hot tokens of the tool calls) but will save a lot in the long term.
|
||||||
|
*/
|
||||||
|
auto prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||||
|
std::string suffix = "<|Assistant|>";
|
||||||
|
if (vocab && !llama_vocab_get_add_eos(vocab) &&
|
||||||
|
inputs.add_generation_prompt &&
|
||||||
|
!string_ends_with(prompt, suffix))
|
||||||
|
{
|
||||||
|
prompt += "<|end▁of▁sentence|>";
|
||||||
|
}
|
||||||
|
data.prompt = prompt;
|
||||||
data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
|
data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) {
|
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) {
|
||||||
static std::regex trigger_regex("<|tool▁calls▁begin|>");
|
static std::regex trigger_regex("<|tool▁calls▁begin|>");
|
||||||
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
static std::regex function_regex(R"(<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n)");
|
||||||
static std::regex close_regex("```<|tool▁call▁end|>");
|
static std::regex close_regex("```<|tool▁call▁end|>");
|
||||||
return parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);
|
static std::regex think_regex(R"(<think>([\s\S\n]*)</think>([\s\S\r\n]*))");
|
||||||
|
auto msg = parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);
|
||||||
|
std::smatch match;
|
||||||
|
if (std::regex_match(msg.content, match, think_regex)) {
|
||||||
|
msg.thoughts = string_trim(match[1].str());
|
||||||
|
msg.content = string_trim(match[2].str());
|
||||||
|
}
|
||||||
|
return msg;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||||
|
|
|
@ -623,6 +623,7 @@ struct common_chat_msg {
|
||||||
std::string role;
|
std::string role;
|
||||||
std::string content;
|
std::string content;
|
||||||
std::vector<common_tool_call> tool_calls;
|
std::vector<common_tool_call> tool_calls;
|
||||||
|
std::string thoughts = "";
|
||||||
std::string tool_plan = "";
|
std::string tool_plan = "";
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -744,6 +744,9 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||||
{"tool_calls", tool_calls},
|
{"tool_calls", tool_calls},
|
||||||
{"role", "assistant"},
|
{"role", "assistant"},
|
||||||
};
|
};
|
||||||
|
if (!msg.thoughts.empty()) {
|
||||||
|
message["thoughts"] = msg.thoughts;
|
||||||
|
}
|
||||||
if (!msg.tool_plan.empty()) {
|
if (!msg.tool_plan.empty()) {
|
||||||
message["tool_plan"] = msg.tool_plan;
|
message["tool_plan"] = msg.tool_plan;
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,18 +18,17 @@
|
||||||
using json = nlohmann::ordered_json;
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
static common_chat_msg msg_from_json(const json & message) {
|
static common_chat_msg msg_from_json(const json & message) {
|
||||||
common_chat_msg ret{
|
common_chat_msg ret;
|
||||||
"assistant",
|
ret.role = "assistant";
|
||||||
"",
|
|
||||||
{},
|
|
||||||
/* .tool_plan = */ "",
|
|
||||||
};
|
|
||||||
if (message.contains("content") && !message.at("content").is_null()) {
|
if (message.contains("content") && !message.at("content").is_null()) {
|
||||||
ret.content = message.at("content");
|
ret.content = message.at("content");
|
||||||
}
|
}
|
||||||
if (message.contains("tool_plan")) {
|
if (message.contains("tool_plan")) {
|
||||||
ret.tool_plan = message.at("tool_plan");
|
ret.tool_plan = message.at("tool_plan");
|
||||||
}
|
}
|
||||||
|
if (message.contains("thoughts")) {
|
||||||
|
ret.thoughts = message.at("thoughts");
|
||||||
|
}
|
||||||
auto has_tool_calls = message.contains("tool_calls");
|
auto has_tool_calls = message.contains("tool_calls");
|
||||||
if (has_tool_calls) {
|
if (has_tool_calls) {
|
||||||
for (const auto & tc : message.at("tool_calls")) {
|
for (const auto & tc : message.at("tool_calls")) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue