DeepSeek R1: parse thoughts / return in separate field in API (non streamed mode)

This commit is contained in:
ochafik 2025-02-03 01:19:15 +00:00
parent 87de852b7f
commit 130ca222c9
4 changed files with 47 additions and 9 deletions

View file

@ -22,6 +22,18 @@ std::string common_chat_format_name(common_chat_format format) {
}
}
// Returns a copy of `s` with leading and trailing whitespace removed.
// "Whitespace" is whatever std::isspace reports for the current C locale.
// An empty or all-whitespace input yields an empty string.
static std::string string_trim(const std::string & s) {
    // std::isspace has undefined behavior for negative arguments other than EOF, and plain
    // `char` may be signed (e.g. UTF-8 bytes >= 0x80), so convert to unsigned char first.
    const auto is_space = [](char c) {
        return std::isspace(static_cast<unsigned char>(c)) != 0;
    };
    size_t start = 0;
    while (start < s.size() && is_space(s[start])) {
        start++;
    }
    size_t end = s.size();
    while (end > start && is_space(s[end - 1])) {
        end--;
    }
    return s.substr(start, end - start);
}
const common_grammar_options grammar_options {
/* .dotall = */ false,
/* .compact_spaces = */ false,
@ -537,20 +549,43 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
});
data.grammar_triggers.push_back({"<tool▁calls▁begin>", /* .at_start = */ false});
data.preserved_tokens = {
"<think>",
"</think>",
"<tool▁sep>",
"<tool▁call▁end>",
};
builder.add_rule("root", "\"<tool▁calls▁begin>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
}, grammar_options);
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
/*
Note: we do not feed the thoughts back to the template for a few reasons:
- the template doesn't use them explicitly
- if content isn't null, tool calls aren't rendered
- not having the thoughts will locally reset the KV cache (losing the hot tokens of the tool calls) but will save a lot in the long term.
*/
auto prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
std::string suffix = "<Assistant>";
if (vocab && !llama_vocab_get_add_eos(vocab) &&
inputs.add_generation_prompt &&
!string_ends_with(prompt, suffix))
{
prompt += "<end▁of▁sentence>";
}
data.prompt = prompt;
data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
return data;
}
// Parses raw DeepSeek R1 output into a chat message.
//
// Tool calls are extracted by parse_json_tool_calls using the DeepSeek delimiters below.
// If the remaining content has the form "<think>...</think>rest", the text between the tags
// is stored in msg.thoughts and the rest becomes msg.content (both whitespace-trimmed).
// Otherwise the message returned by parse_json_tool_calls is passed through unchanged.
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) {
    static std::regex trigger_regex("<tool▁calls▁begin>");
    static std::regex function_regex(R"(<tool▁call▁begin>function<tool▁sep>([^\n]+)\n```json\n)");
    static std::regex close_regex("```<tool▁call▁end>");
    static const std::string think_open  = "<think>";
    static const std::string think_close = "</think>";

    auto msg = parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);

    // Same result as a full match of <think>([\s\S]*)</think>([\s\S]*): the content must start
    // with the opening tag and the greedy first group ends at the LAST closing tag. Plain string
    // search is used instead of std::regex because backtracking over [\s\S]* can recurse once per
    // character and overflow the stack on long reasoning traces.
    const std::string & content = msg.content;
    if (content.compare(0, think_open.size(), think_open) == 0) {
        const size_t close_pos = content.rfind(think_close);
        if (close_pos != std::string::npos && close_pos >= think_open.size()) {
            // Compute both pieces before overwriting msg.content, which `content` aliases.
            const std::string thoughts = string_trim(content.substr(think_open.size(), close_pos - think_open.size()));
            const std::string rest     = string_trim(content.substr(close_pos + think_close.size()));
            msg.thoughts = thoughts;
            msg.content  = rest;
        }
    }
    return msg;
}
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {