Prepare DeepSeek-R1-Distill-Llama-8B support
parent 09971e626c · commit 92ac336dfa
5 changed files with 105 additions and 64 deletions
@@ -6,8 +6,8 @@
 const common_grammar_options grammar_options {
     /* .dotall = */ false,
-    // /* .compact_spaces = */ false,
-    /* .compact_spaces = */ true,
+    /* .compact_spaces = */ false,
+    // /* .compact_spaces = */ true,
 };
 
 static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
@@ -59,13 +59,11 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
  * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
  * Aggregates the prefix, suffix and in-between text into the content.
  */
-static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool allow_raw_python = false) {
+static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::optional<std::regex> & trigger_opt, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool allow_raw_python = false) {
     std::smatch match;
 
     common_chat_msg result;
     result.role = "assistant";
-    auto end = input.end();
-    auto it = input.begin();
 
     std::vector<std::string> tool_names;
     if (check_names) {
@@ -77,6 +75,18 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
         }
     }
 
+    auto end = input.end();
+    auto it = input.begin();
+
+    if (trigger_opt) {
+        if (!std::regex_search(it, end, match, *trigger_opt)) {
+            result.content = input;
+            return result;
+        }
+        result.content = match.prefix().str();
+        it = match.suffix().first;
+    }
+
     while (it != end) {
         std::sregex_iterator rend;
         std::sregex_iterator rit(it, end, function_regex);
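Note: the new `trigger_opt` parameter lets `parse_json_tool_calls` treat everything before an opening marker as plain assistant content, and bail out entirely when the marker never appears. A minimal standalone sketch of just that gating logic (the marker and input strings are invented for illustration):

```cpp
#include <iostream>
#include <optional>
#include <regex>
#include <string>

int main() {
    // Hypothetical trigger; the DeepSeek-R1 handler below uses "<|tool▁calls▁begin|>".
    std::optional<std::regex> trigger_opt{std::regex("<TOOLCALL>")};
    const std::string input = "Let me call a tool. <TOOLCALL>{\"name\": \"f\", \"arguments\": {}}";

    auto it  = input.begin();
    auto end = input.end();
    std::smatch match;

    if (trigger_opt) {
        if (!std::regex_search(it, end, match, *trigger_opt)) {
            // No trigger anywhere: the whole completion is plain content.
            std::cout << "content: " << input << "\n";
            return 0;
        }
        std::cout << "content: " << match.prefix().str() << "\n"; // "Let me call a tool. "
        it = match.suffix().first; // tool-call parsing continues from here
    }
    std::cout << "to parse: " << std::string(it, end) << "\n";
}
```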
@@ -142,24 +152,6 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in
     return result;
 }
 
-static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
-    json messages_with_system = messages;
-
-    if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
-        std::string existing_system = messages_with_system.at(0).at("content");
-        messages_with_system[0] = json {
-            {"role", "system"},
-            {"content", existing_system + "\n" + system_prompt},
-        };
-    } else {
-        messages_with_system.insert(messages_with_system.begin(), json {
-            {"role", "system"},
-            {"content", system_prompt},
-        });
-    }
-    return messages_with_system;
-}
-
 class text_chat_parser : public common_chat_parser {
 public:
     std::optional<common_chat_msg> parse_partial(const std::string & input) override {
@@ -291,12 +283,11 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
         builder.add_schema("root", schema);
     }, grammar_options);
 
-    // TODO: add schema to system prompt.
-    auto tweaked_messages = add_system(
+    auto tweaked_messages = common_chat_template::add_system(
         params.messages,
         "Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
 
-    data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
         json data = json::parse(input);
         common_chat_msg result;
@@ -363,7 +354,7 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
     if (params.tool_choice != "required") {
         data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
     }
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
         return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
     });
@@ -396,14 +387,13 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
         builder.add_rule("root", string_join(tool_rules, " | "));
     }, grammar_options);
     data.additional_stops.push_back("<|eom_id|>");
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {
        {"builtin_tools", builtin_tools},
     });
     data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
         static std::regex function_regex("<\\|python_tag\\|>\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
         static std::regex close_regex("\\}");
-        auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
-        return res;
+        return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
     });
     fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
     return data;
@@ -438,17 +428,31 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
         builder.add_rule("root", string_join(tool_rules, " | "));
     }, grammar_options);
     data.additional_stops.push_back("<|eom_id|>");
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {});
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {});
     data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
         static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
         static std::regex close_regex("\\}");
-        auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
+        auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
         return res;
     });
     fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
     return data;
 }
 
+static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
+    fprintf(stderr, "[%s]\n", __func__);
+    common_chat_data data;
+    data.grammar = "root ::= .*";
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
+    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
+        static std::regex trigger_regex("<|tool▁calls▁begin|>");
+        static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^<]+)\n```json\n");
+        static std::regex close_regex("```<|tool▁call▁end|>");
+        return parse_json_tool_calls(params.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
+    });
+    return data;
+}
+
 static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     fprintf(stderr, "[%s]\n", __func__);
     common_chat_data data;
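The three DeepSeek regexes are built from the template's marker tokens, and `data.grammar = "root ::= .*"` is a permissive placeholder rather than a real constraint. One review note: `|` is a regex alternation metacharacter, so these literals would likely need escaping (e.g. `<\\|tool▁calls▁begin\\|>`) to match the markers verbatim. For reference, a sketch of the completion shape this parser expects (the weather call is invented):

```cpp
#include <iostream>
#include <string>

int main() {
    // Illustrative DeepSeek-R1 tool-call output; the marker tokens come from
    // the template added in this commit, the function call is made up.
    const std::string completion =
        "I'll look that up."
        "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather\n"
        "```json\n{\"location\": \"Paris\"}\n```"
        "<|tool▁call▁end|><|tool▁calls▁end|>";
    std::cout << completion << "\n";
}
```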
@@ -481,7 +485,7 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
     if (params.tool_choice != "required") {
         data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
     }
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
         return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
     });
@@ -519,12 +523,12 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const
 
     }, grammar_options);
 
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
         static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
         static std::regex close_regex(R"($|(?=>>>))");
 
-        auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
+        auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
         if (res.content.find("all\n") == 0) {
             res.content = res.content.substr(4);
         }
@@ -587,7 +591,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
         }
     }, grammar_options);
 
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.parser = std::make_unique<monolithic_chat_parser>([params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
         // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
         static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
@@ -608,7 +612,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
         }
         static std::regex function_regex(R"(<function=(\w+)>)");
         static std::regex close_regex(R"(</function>)");
-        return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false, has_raw_python);
+        return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
     });
     return data;
 }
@@ -640,7 +644,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
         }
     }, grammar_options);
 
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
         try {
             std::regex start_pattern(R"([\n\s]*<tool_call>)");
@@ -691,7 +695,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
 static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
     fprintf(stderr, "[%s]\n", __func__);
     common_chat_data data;
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.parser = std::make_unique<text_chat_parser>();
     if (!params.json_schema.is_null()) {
         if (!params.grammar.empty()) {
@@ -733,6 +737,9 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
             return common_chat_init_llama_3_2_tool_calls(tmpl, params);
         }
     }
+    if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
+        return common_chat_init_deepseek_r1_tool_call(tmpl, params);
+    }
     // if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
     //     TODO: Command-R-Plus
     // }
@@ -23,6 +23,7 @@ struct common_chat_params {
     bool parallel_tool_calls;
     bool stream;
     std::string grammar;
+    bool add_generation_prompt = true;
 };
 
 class common_chat_parser {
@@ -22,6 +22,7 @@ class chat_template {
 
   private:
     bool supports_tools_ = true;
+    bool supports_tool_calls_ = true;
     // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
     // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
     bool requires_object_arguments_ = false;
@@ -59,7 +60,13 @@ class chat_template {
             /* .lstrip_blocks = */ true,
             /* .keep_trailing_newline = */ false,
         });
-        supports_tools_ = source.find("tools") != std::string::npos;
+        supports_tool_calls_ = source.find("tool_calls") != std::string::npos;
+        supports_tools_ =
+            try_raw_render({
+                {{"role", "user"}, {"content", "Hey"}},
+            }, {
+                {{"name", "some_tool"}, {"parameters", {{"type", "string"}}}},
+            }, false).find("some_tool") != std::string::npos;
 
         requires_object_arguments_ =
             try_raw_render({
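`supports_tools_` is now probed behaviorally instead of by grepping the source for "tools": render a one-message chat with a dummy tool and check whether the tool's name survives into the output, while the cheaper `source.find("tool_calls")` check moves to the new `supports_tool_calls_` flag. A reduced sketch of the probe idea (`try_raw_render` is stubbed out here; the real one renders the Jinja template):

```cpp
#include <iostream>
#include <string>

// Stub for chat_template::try_raw_render: a template that actually loops over
// its `tools` argument would echo the probe tool's name into the prompt.
static std::string try_raw_render(bool template_renders_tools) {
    return template_renders_tools
        ? "User: Hey\nYou may call: some_tool"
        : "User: Hey";
}

int main() {
    bool supports_tools = try_raw_render(true).find("some_tool") != std::string::npos;
    std::cout << std::boolalpha << supports_tools << "\n"; // true
}
```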
@@ -120,6 +127,7 @@ class chat_template {
     const std::string & bos_token() const { return bos_token_; }
     const std::string & eos_token() const { return eos_token_; }
     bool supports_tools() const { return supports_tools_; }
+    bool supports_tool_calls() const { return supports_tool_calls_; }
     bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
 
     std::string apply(
@@ -152,7 +160,7 @@ class chat_template {
             actual_tools = tools;
         }
 
-        if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) {
+        if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || !supports_tool_calls_ || requires_typed_content_)) {
             actual_messages = json::array();
 
             auto add_message = [&](const json & msg) {
@@ -179,7 +187,9 @@ class chat_template {
                     pending_system.clear();
                 }
             };
-            for (const auto & message_ : messages) {
+            auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !supports_tools_;
+
+            for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) {
                 auto message = message_;
                 if (!message.contains("role") || !message.contains("content")) {
                     throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
@@ -187,7 +197,7 @@ class chat_template {
                 std::string role = message.at("role");
 
                 if (message.contains("tool_calls")) {
-                    if (requires_object_arguments_ || !supports_tools_) {
+                    if (requires_object_arguments_ || !supports_tool_calls_) {
                         for (auto & tool_call : message.at("tool_calls")) {
                             if (tool_call["type"] == "function") {
                                 auto & function = tool_call.at("function");
@@ -201,7 +211,7 @@ class chat_template {
                         }
                     }
                 }
-                if (!supports_tools_) {
+                if (!supports_tool_calls_) {
                     auto content = message.at("content");
                     auto tool_calls = json::array();
                     for (const auto & tool_call : message.at("tool_calls")) {
@@ -262,7 +272,9 @@ class chat_template {
                 }
                 add_message(message);
             }
+            if (!supports_system_role_) {
                 flush_sys();
+            }
         } else {
             actual_messages = messages;
         }
@@ -287,6 +299,24 @@ class chat_template {
 
         return template_root_->render(context);
     }
+
+    static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
+        json messages_with_system = messages;
+
+        if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
+            std::string existing_system = messages_with_system.at(0).at("content");
+            messages_with_system[0] = json {
+                {"role", "system"},
+                {"content", existing_system + "\n" + system_prompt},
+            };
+        } else {
+            messages_with_system.insert(messages_with_system.begin(), json {
+                {"role", "system"},
+                {"content", system_prompt},
+            });
+        }
+        return messages_with_system;
+    }
 };
 
 } // namespace minja
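`add_system` (deleted from the handler file earlier in this diff) is re-homed here as a static member so both the generic JSON handler and the new tools-in-system fallback can share it: it merges the prompt into an existing leading system message, or prepends a new one. An example of the merge rule in action (standalone re-implementation, for illustration only):

```cpp
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::ordered_json;

// Same merge rule as minja::chat_template::add_system above.
static json add_system(const json & messages, const std::string & system_prompt) {
    json out = messages;
    if (out.size() > 0 && out[0].at("role") == "system") {
        // Extend the existing system message rather than inserting a second one.
        out[0]["content"] = out[0].at("content").get<std::string>() + "\n" + system_prompt;
    } else {
        out.insert(out.begin(), json {{"role", "system"}, {"content", system_prompt}});
    }
    return out;
}

int main() {
    json msgs = json::array({{{"role", "user"}, {"content", "Hi"}}});
    // No system message yet, so one is prepended before the user turn.
    std::cout << add_system(msgs, "Available tools: [...]").dump(2) << "\n";
}
```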
@@ -0,0 +1 @@
+{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}
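This one-line Jinja template is the DeepSeek-R1 distill chat format: it hoists any system message to the front, wraps assistant tool calls in the `<|tool▁calls▁begin|>`/`<|tool▁call▁end|>` markers with the arguments in a fenced json block, and drops everything up to `</think>` from prior assistant turns so reasoning traces do not re-enter the context. Rendered against one system and one user message with `add_generation_prompt` set (assuming `bos_token` is `<s>`, as in the test below), it yields roughly:

```
<s>You are helpful.<|User|>Hi<|Assistant|>
```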
@@ -134,10 +134,7 @@ const auto python_tool = json::parse(R"({
     }
   }
 })");
-const auto code_interpreter_tool = json::parse(R"({
-  "type": "code_interpreter"
-})");
-const json tools = {special_function_tool, code_interpreter_tool};
+const json tools = {special_function_tool, python_tool};
 
 // static void test_parsing() {
 //     json request = {
@@ -348,6 +345,7 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
     params.tools = tools;
     std::string prefix = common_chat_init(tmpl, params).prompt;
     params.messages.push_back(delta_message);
+    params.add_generation_prompt = false;
     std::string full = common_chat_init(tmpl, params).prompt;
 
     // Check full starts with prefix
@@ -412,7 +410,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
 static void test_grammars() {
     auto tool_call_message = json {
         {"role", "assistant"},
-        {"content", ""},
+        {"content", {}},
         {"tool_calls", json {{
             {"type", "function"},
             {"function", {
@@ -426,7 +424,7 @@ static void test_grammars() {
 
     auto python_tool_call_message = json {
         {"role", "assistant"},
-        {"content", ""},
+        {"content", {}},
         {"tool_calls", json {{
             {"type", "function"},
             {"function", {
@@ -442,18 +440,18 @@ static void test_grammars() {
     }
     // {
     //     const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
-    //     assert_equals(tmpl.requires_object_arguments_, true);
+    //     // assert_equals(tmpl.requires_object_arguments_, true);
     //     test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
     //     test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools);
     // }
-    // {
-    //     const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
-    //     test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
-    // }
-    // {
-    //     const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
-    //     test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
-    // }
+    {
+        const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
+        test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
+    }
+    {
+        const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
+        test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
+    }
     // {
     //     const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
     //     test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
@@ -462,10 +460,10 @@ static void test_grammars() {
     //     const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
     //     test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools);
     // }
-    // {
-    //     const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
-    //     test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
-    // }
+    {
+        const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
+        test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
+    }
     // {
     //     const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
     //     test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
@@ -490,6 +488,10 @@ static void test_grammars() {
         const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
         test_template(tmpl, { "<|end|>" }, tool_call_message_with_id, tools);
     }
+    {
+        const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
+        test_template(tmpl, {}, tool_call_message, tools);
+    }
 }
 
 int main() {