Prepare DeepSeek-R1-Distill-Llama-8B support

ochafik 2025-01-27 17:26:43 +00:00
parent 09971e626c
commit 92ac336dfa
5 changed files with 105 additions and 64 deletions

View file

@@ -6,8 +6,8 @@
const common_grammar_options grammar_options {
/* .dotall = */ false,
// /* .compact_spaces = */ false,
/* .compact_spaces = */ true,
/* .compact_spaces = */ false,
// /* .compact_spaces = */ true,
};
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
@@ -59,13 +59,11 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
 * Takes a prefix regex that must have one capture group for the function name, and a closing regex; expects JSON parameters in between.
* Aggregates the prefix, suffix and in-between text into the content.
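 * For example (illustrative values, not from the tests): with function_regex R"(<function=(\w+)>)" and
 * close_regex R"(</function>)", an input of `Sure.<function=get_weather>{"city": "Paris"}</function>`
 * yields one tool call named "get_weather" with arguments {"city": "Paris"}, and "Sure." as the content.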
*/
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool allow_raw_python = false) {
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::optional<std::regex> & trigger_opt, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool allow_raw_python = false) {
std::smatch match;
common_chat_msg result;
result.role = "assistant";
auto end = input.end();
auto it = input.begin();
std::vector<std::string> tool_names;
if (check_names) {
@@ -77,6 +75,18 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
}
}
auto end = input.end();
auto it = input.begin();
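// If a trigger regex is given (e.g. DeepSeek R1's <tool▁calls▁begin>), everything before its first match
// becomes plain content and tool-call parsing starts right after it; if it never matches, the whole input
// is returned as content.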
if (trigger_opt) {
if (!std::regex_search(it, end, match, *trigger_opt)) {
result.content = input;
return result;
}
result.content = match.prefix().str();
it = match.suffix().first;
}
while (it != end) {
std::sregex_iterator rend;
std::sregex_iterator rit(it, end, function_regex);
@@ -142,24 +152,6 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in
return result;
}
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
json messages_with_system = messages;
if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
std::string existing_system = messages_with_system.at(0).at("content");
messages_with_system[0] = json {
{"role", "system"},
{"content", existing_system + "\n" + system_prompt},
};
} else {
messages_with_system.insert(messages_with_system.begin(), json {
{"role", "system"},
{"content", system_prompt},
});
}
return messages_with_system;
}
class text_chat_parser : public common_chat_parser {
public:
std::optional<common_chat_msg> parse_partial(const std::string & input) override {
@@ -291,12 +283,11 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
builder.add_schema("root", schema);
}, grammar_options);
// TODO: add schema to system prompt.
auto tweaked_messages = add_system(
auto tweaked_messages = common_chat_template::add_system(
params.messages,
"Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
json data = json::parse(input);
common_chat_msg result;
@@ -363,7 +354,7 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
}
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
});
@@ -396,14 +387,13 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
builder.add_rule("root", string_join(tool_rules, " | "));
}, grammar_options);
data.additional_stops.push_back("<|eom_id|>");
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {
{"builtin_tools", builtin_tools},
});
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
static std::regex function_regex("<\\|python_tag\\|>\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
static std::regex close_regex("\\}");
auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
return res;
return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
});
fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
return data;
@@ -438,17 +428,31 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
builder.add_rule("root", string_join(tool_rules, " | "));
}, grammar_options);
data.additional_stops.push_back("<|eom_id|>");
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {});
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {});
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
static std::regex close_regex("\\}");
auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
return res;
});
fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
return data;
}
static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
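// Permissive placeholder grammar for now: accept any output and rely on the parser below to extract tool calls.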
data.grammar = "root ::= .*";
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
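// Illustrative shape of the output these regexes target (delimiters as emitted by the chat template below;
// the function name and arguments here are made up):
//   <tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
//   ```json
//   {"city": "Paris"}
//   ```<tool▁call▁end><tool▁calls▁end>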
static std::regex trigger_regex("<tool▁calls▁begin>");
static std::regex function_regex("<tool▁call▁begin>function<tool▁sep>([^<]+)\n```json\n");
static std::regex close_regex("```<tool▁call▁end>");
return parse_json_tool_calls(params.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
});
return data;
}
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
@@ -481,7 +485,7 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
}
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
});
@@ -519,12 +523,12 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
static std::regex close_regex(R"($|(?=>>>))");
auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
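// Functionary marks plain-text (non-tool-call) responses with the pseudo-function "all"; strip that prefix from the content.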
if (res.content.find("all\n") == 0) {
res.content = res.content.substr(4);
}
@@ -587,7 +591,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
}
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.parser = std::make_unique<monolithic_chat_parser>([params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
@@ -608,7 +612,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
}
static std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)");
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false, has_raw_python);
return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
});
return data;
}
@@ -640,7 +644,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
}
}, grammar_options);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
try {
std::regex start_pattern(R"([\n\s]*<tool_call>)");
@@ -691,7 +695,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.parser = std::make_unique<text_chat_parser>();
if (!params.json_schema.is_null()) {
if (!params.grammar.empty()) {
@@ -733,6 +737,9 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
return common_chat_init_llama_3_2_tool_calls(tmpl, params);
}
}
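// DeepSeek R1 (and its distills) wrap tool calls in <tool▁calls▁begin> ... <tool▁calls▁end>.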
if (src.find("<tool▁calls▁begin>") != std::string::npos) {
return common_chat_init_deepseek_r1_tool_call(tmpl, params);
}
// if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
// TODO: Command-R-Plus
// }

View file

@@ -23,6 +23,7 @@ struct common_chat_params {
bool parallel_tool_calls;
bool stream;
std::string grammar;
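// When false, the rendered prompt stops after the last message; the tests use this to compute per-message deltas.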
bool add_generation_prompt = true;
};
class common_chat_parser {

View file

@@ -22,6 +22,7 @@ class chat_template {
private:
bool supports_tools_ = true;
bool supports_tool_calls_ = true;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool requires_object_arguments_ = false;
@@ -59,7 +60,13 @@ class chat_template {
/* .lstrip_blocks = */ true,
/* .keep_trailing_newline = */ false,
});
supports_tools_ = source.find("tools") != std::string::npos;
supports_tool_calls_ = source.find("tool_calls") != std::string::npos;
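// Probe for tool support by rendering a dummy conversation that passes a dummy tool and checking
// whether the tool's name makes it into the rendered output.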
supports_tools_ =
try_raw_render({
{{"role", "user"}, {"content", "Hey"}},
}, {
{{"name", "some_tool"}, {"parameters", {{"type", "string"}}}},
}, false).find("some_tool") != std::string::npos;
requires_object_arguments_ =
try_raw_render({
@@ -120,6 +127,7 @@ class chat_template {
const std::string & bos_token() const { return bos_token_; }
const std::string & eos_token() const { return eos_token_; }
bool supports_tools() const { return supports_tools_; }
bool supports_tool_calls() const { return supports_tool_calls_; }
bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
std::string apply(
@@ -152,7 +160,7 @@ class chat_template {
actual_tools = tools;
}
if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) {
if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || !supports_tool_calls_ || requires_typed_content_)) {
actual_messages = json::array();
auto add_message = [&](const json & msg) {
@@ -179,7 +187,9 @@ class chat_template {
pending_system.clear();
}
};
for (const auto & message_ : messages) {
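// Templates that don't support tools natively get the tool definitions injected into the system message instead.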
auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !supports_tools_;
for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) {
auto message = message_;
if (!message.contains("role") || !message.contains("content")) {
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
@@ -187,7 +197,7 @@ class chat_template {
std::string role = message.at("role");
if (message.contains("tool_calls")) {
if (requires_object_arguments_ || !supports_tools_) {
if (requires_object_arguments_ || !supports_tool_calls_) {
for (auto & tool_call : message.at("tool_calls")) {
if (tool_call["type"] == "function") {
auto & function = tool_call.at("function");
@@ -201,7 +211,7 @@ class chat_template {
}
}
}
if (!supports_tools_) {
if (!supports_tool_calls_) {
auto content = message.at("content");
auto tool_calls = json::array();
for (const auto & tool_call : message.at("tool_calls")) {
@@ -262,7 +272,9 @@ class chat_template {
}
add_message(message);
}
flush_sys();
if (!supports_system_role_) {
flush_sys();
}
} else {
actual_messages = messages;
}
@@ -287,6 +299,24 @@ class chat_template {
return template_root_->render(context);
}
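// Prepends system_prompt as a new system message, or appends it to an existing leading system message.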
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
json messages_with_system = messages;
if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
std::string existing_system = messages_with_system.at(0).at("content");
messages_with_system[0] = json {
{"role", "system"},
{"content", existing_system + "\n" + system_prompt},
};
} else {
messages_with_system.insert(messages_with_system.begin(), json {
{"role", "system"},
{"content", system_prompt},
});
}
return messages_with_system;
}
};
} // namespace minja

View file

@@ -0,0 +1 @@
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<tool▁call▁end>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<tool▁call▁end>'}}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<Assistant>' + content + '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<Assistant>'}}{% endif %}
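For reference, a minimal sketch of what this template renders to, assuming a BOS token of "<s>" (as in the tests below), a single user message "Hey", and add_generation_prompt set to true:

<s><User>Hey<Assistant>

Assistant turns that carry tool calls are instead wrapped in <tool▁calls▁begin> ... <tool▁calls▁end>, with each call's arguments in a fenced json block, which is what the DeepSeek R1 parser above matches.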

View file

@@ -134,10 +134,7 @@ const auto python_tool = json::parse(R"({
}
}
})");
const auto code_interpreter_tool = json::parse(R"({
"type": "code_interpreter"
})");
const json tools = {special_function_tool, code_interpreter_tool};
const json tools = {special_function_tool, python_tool};
// static void test_parsing() {
// json request = {
@@ -348,6 +345,7 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
params.tools = tools;
std::string prefix = common_chat_init(tmpl, params).prompt;
params.messages.push_back(delta_message);
params.add_generation_prompt = false;
std::string full = common_chat_init(tmpl, params).prompt;
// Check full starts with prefix
@@ -412,7 +410,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
static void test_grammars() {
auto tool_call_message = json {
{"role", "assistant"},
{"content", ""},
{"content", {}},
{"tool_calls", json {{
{"type", "function"},
{"function", {
@@ -426,7 +424,7 @@ static void test_grammars() {
auto python_tool_call_message = json {
{"role", "assistant"},
{"content", ""},
{"content", {}},
{"tool_calls", json {{
{"type", "function"},
{"function", {
@@ -442,18 +440,18 @@
}
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
// assert_equals(tmpl.requires_object_arguments_, true);
// // assert_equals(tmpl.requires_object_arguments_, true);
// test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
// test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools);
// }
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
// }
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
// }
{
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
}
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
@@ -462,10 +460,10 @@ static void test_grammars() {
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools);
// }
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
// }
{
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
}
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
@@ -490,6 +488,10 @@ static void test_grammars() {
const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|end|>" }, tool_call_message_with_id, tools);
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
test_template(tmpl, {}, tool_call_message, tools);
}
}
int main() {