DeepSeek-R1: implement grammar constraints
This commit is contained in:
parent
92ac336dfa
commit
118f799ae4
2 changed files with 23 additions and 3 deletions
|
@ -443,10 +443,26 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
|
|||
fprintf(stderr, "[%s]\n", __func__);
|
||||
common_chat_data data;
|
||||
data.grammar = "root ::= .*";
|
||||
// data.grammar = "root ::= .*";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
|
||||
});
|
||||
if (params.tool_choice != "required") {
|
||||
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
|
||||
}
|
||||
builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (params.parallel_tool_calls ? "*" : "") + " space");
|
||||
}, grammar_options);
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
||||
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
||||
static std::regex trigger_regex("<|tool▁calls▁begin|>");
|
||||
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^<]+)\n```json\n");
|
||||
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
||||
static std::regex close_regex("```<|tool▁call▁end|>");
|
||||
return parse_json_tool_calls(params.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
|
||||
});
|
||||
|
|
|
@ -353,6 +353,10 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
|
|||
throw std::runtime_error("Full message does not start with prefix");
|
||||
}
|
||||
|
||||
if (full == prefix) {
|
||||
throw std::runtime_error("Full message is the same as the prefix");
|
||||
}
|
||||
|
||||
auto delta = full.substr(prefix.size());
|
||||
|
||||
// Strip end tokens
|
||||
|
@ -398,7 +402,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
|||
|
||||
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
|
||||
{"role", "assistant"},
|
||||
{"content", ""},
|
||||
{"content", {}},
|
||||
{"tool_calls", tool_calls}
|
||||
}, tools);
|
||||
if (!match_string(content_less_delta, grammar.get())) {
|
||||
|
@ -490,7 +494,7 @@ static void test_grammars() {
|
|||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
|
||||
test_template(tmpl, {}, tool_call_message, tools);
|
||||
test_template(tmpl, { "<|end▁of▁sentence|>" }, tool_call_message, tools);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue