DeepSeek-R1: implement grammar constraints

This commit is contained in:
ochafik 2025-01-27 17:51:13 +00:00
parent 92ac336dfa
commit 118f799ae4
2 changed files with 23 additions and 3 deletions

View file

@@ -443,10 +443,26 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
data.grammar = "root ::= .*";
// data.grammar = "root ::= .*";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(params.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
auto args_rule = builder.add_schema(name + "-args", parameters);
tool_rules.push_back(builder.add_rule(name + "-call",
"\"<tool▁call▁begin>function<tool▁sep>" + name + "\\n```json\\n\" " + args_rule + " \"```<tool▁call▁end>\""));
});
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<tool▁calls▁begin>", /* .at_start = */ false});
}
builder.add_rule("root", "\"<tool▁calls▁begin>\" (" + string_join(tool_rules, " | ") + ")" + (params.parallel_tool_calls ? "*" : "") + " space");
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
static std::regex trigger_regex("<tool▁calls▁begin>");
static std::regex function_regex("<tool▁call▁begin>function<tool▁sep>([^<]+)\n```json\n");
static std::regex function_regex("<tool▁call▁begin>function<tool▁sep>([^\n]+)\n```json\n");
static std::regex close_regex("```<tool▁call▁end>");
return parse_json_tool_calls(params.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
});

View file

@@ -353,6 +353,10 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
throw std::runtime_error("Full message does not start with prefix");
}
if (full == prefix) {
throw std::runtime_error("Full message is the same as the prefix");
}
auto delta = full.substr(prefix.size());
// Strip end tokens
@@ -398,7 +402,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
{"role", "assistant"},
{"content", ""},
{"content", {}},
{"tool_calls", tool_calls}
}, tools);
if (!match_string(content_less_delta, grammar.get())) {
@@ -490,7 +494,7 @@ static void test_grammars() {
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
test_template(tmpl, {}, tool_call_message, tools);
test_template(tmpl, { "<end▁of▁sentence>" }, tool_call_message, tools);
}
}