This commit is contained in:
ochafik 2025-01-28 23:46:51 +00:00
parent cad1448ac7
commit 4f257550a2

View file

@ -17,19 +17,26 @@ using json = nlohmann::ordered_json;
namespace minja { namespace minja {
struct chat_template_caps {
bool supports_tools = false;
bool supports_tool_calls = false;
bool supports_tool_responses = false;
bool supports_system_role = false;
bool supports_parallel_tool_calls = false;
bool supports_tool_call_id = false;
// meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool requires_object_arguments = false;
// CohereForAI/c4ai-command-r-plus simple variant
bool requires_non_null_content = false;
// MiniMaxAI/MiniMax-Text-01 special
bool requires_typed_content = false;
};
class chat_template { class chat_template {
public:
private: private:
bool supports_tools_ = true; chat_template_caps caps_;
bool supports_tool_calls_ = true;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool requires_object_arguments_ = false;
bool requires_typed_content_ = false;
bool supports_system_role_ = true;
bool supports_parallel_tool_calls_ = false;
bool supports_code_interpreter_ = false;
std::string source_; std::string source_;
std::string bos_token_; std::string bos_token_;
std::string eos_token_; std::string eos_token_;
@ -43,15 +50,16 @@ class chat_template {
{ {
try { try {
auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false); auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false);
// fprintf(stderr, "Prompt: %s\n", prompt.c_str()); // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
return prompt; return prompt;
} catch (const std::exception & e) { } catch (const std::exception & e) {
// fprintf(stderr, "Error: %s\n", e.what()); // fprintf(stderr, "try_raw_render error: %s\n", e.what());
return ""; return "";
} }
} }
public: public:
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token) chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
: source_(source), bos_token_(bos_token), eos_token_(eos_token) : source_(source), bos_token_(bos_token), eos_token_(eos_token)
{ {
@ -60,82 +68,120 @@ class chat_template {
/* .lstrip_blocks = */ true, /* .lstrip_blocks = */ true,
/* .keep_trailing_newline = */ false, /* .keep_trailing_newline = */ false,
}); });
supports_tool_calls_ = source.find("tool_calls") != std::string::npos;
supports_tools_ = auto contains = [](const std::string & haystack, const std::string & needle) {
try_raw_render({ return haystack.find(needle) != std::string::npos;
{{"role", "user"}, {"content", "Hey"}}, };
}, {
const std::string user_needle = "<User Needle>";
const std::string sys_needle = "<System Needle>";
const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}};
const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}};
caps_.requires_typed_content =
!contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false), user_needle)
&& contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false), user_needle);
const auto dummy_user_msg = caps_.requires_typed_content
? dummy_typed_user_msg
: dummy_str_user_msg;
const json needle_system_msg = {
{"role", "system"},
{"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)},
};
caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle);
auto out = try_raw_render(json::array({
dummy_user_msg
}), json::array({
{ {
{"name", "some_tool"},
{"type", "function"}, {"type", "function"},
{"function", { {"function", {
{"name", "some_tool"}, {"name", "some_tool"},
{"parameters", {{"type", "string"}}}, {"description", "Some tool."},
{"parameters", {
{"type", "object"},
{"properties", {
{"arg", {
{"type", "string"},
{"description", "Some argument."},
}},
}},
{"required", json::array({ "arg" })},
}},
}}, }},
}, },
}, false).find("some_tool") != std::string::npos; }), false);
caps_.supports_tools = contains(out, "some_tool");
requires_object_arguments_ = auto make_tool_calls_msg = [&](const json & tool_calls) {
try_raw_render({ return json {
{
{"role", "user"},
{"content", "Hey"}
},
{
{"role", "assistant"}, {"role", "assistant"},
{"tool_calls", json::array({ {"content", nullptr},
{ {"tool_calls", tool_calls},
};
};
auto make_tool_call = [](const std::string & tool_name, const json & arguments) {
return json {
{"id", "call_1___"}, {"id", "call_1___"},
{"type", "function"}, {"type", "function"},
{"function", { {"function", {
{"arguments", { {"arguments", arguments},
{"code", "print('Hello, World!')"}, {"name", tool_name},
}}, }},
{"name", "ipython"}, };
}}, };
}, const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};
})},
// Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
}), {}, false);
auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
}), {}, false);
auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false);
auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false);
caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle);
if (caps_.supports_tool_calls) {
auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump());
auto tc1 = make_tool_call("test_tool1", dummy_args);
auto tc2 = make_tool_call("test_tool2", dummy_args);
auto out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({tc1, tc2})),
}), {}, false);
caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2");
out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({tc1})),
{
{"role", "tool"},
{"name", "test_tool1"},
{"content", "Some response!"},
{"tool_call_id", "call_911_"},
} }
}, {}, false).find("{\"code\": \"print") != std::string::npos }), {}, false);
&& try_raw_render({ caps_.supports_tool_responses = contains(out, "Some response!");
{ caps_.supports_tool_call_id = contains(out, "call_911_");
{"role", "user"},
{"content", "Hey"}
},
{
{"role", "assistant"},
{"tool_calls", json::array({
{
{"id", "call_1___"},
{"type", "function"},
{"function", {
{"arguments", "{\"code\": \"print('Hello, World!')\"}"},
{"name", "ipython"},
}},
},
})},
} }
}, {}, false).find("{\"code\": \"print") == std::string::npos;
supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;
supports_system_role_ = try_raw_render({
{{"role", "system"}, {"content", "<System Needle>"}},
{{"role", "user"}, {"content", "Hey"}}
}, {}, false).find("<System Needle>") != std::string::npos;
requires_typed_content_ = try_raw_render({{{"role", "user"}, {"content", "Hey"}}}, {}, false).find("Hey") == std::string::npos
&& try_raw_render({{{"role", "user"}, {"content", {{{"type", "text"}, {"text", "Hey"}}}}}}, {}, false).find("Hey") != std::string::npos;
supports_code_interpreter_ = source.find("code_interpreter") != std::string::npos;
} }
const std::string & source() const { return source_; } const std::string & source() const { return source_; }
const std::string & bos_token() const { return bos_token_; } const std::string & bos_token() const { return bos_token_; }
const std::string & eos_token() const { return eos_token_; } const std::string & eos_token() const { return eos_token_; }
bool supports_tools() const { return supports_tools_; } const chat_template_caps & original_caps() const { return caps_; }
bool supports_tool_calls() const { return supports_tool_calls_; }
bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
bool requires_object_arguments() const { return requires_object_arguments_; }
std::string apply( std::string apply(
const nlohmann::ordered_json & messages, const nlohmann::ordered_json & messages,
@ -145,33 +191,20 @@ class chat_template {
bool adjust_inputs = true) const bool adjust_inputs = true) const
{ {
json actual_messages; json actual_messages;
json actual_tools;
auto has_code_interpreter = false; auto needs_adjustments = adjust_inputs && (false
for (const auto & tool : tools) { || !caps_.supports_system_role
if (tool.contains("type") && tool.at("type") == "code_interpreter") { || !caps_.supports_tools
has_code_interpreter = true; || !caps_.supports_tool_responses
break; || !caps_.supports_tool_calls
} || caps_.requires_object_arguments
} || caps_.requires_typed_content
);
if (adjust_inputs && !tools.is_null() && !supports_code_interpreter_ && has_code_interpreter) { if (needs_adjustments) {
actual_tools = json::array();
for (const auto & tool : tools) {
if (tool.contains("type") && tool.at("type") == "code_interpreter" && !supports_code_interpreter_) {
continue;
}
actual_tools.push_back(tool);
}
} else if (!tools.is_null()) {
actual_tools = tools;
}
if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || !supports_tool_calls_ || requires_typed_content_)) {
actual_messages = json::array(); actual_messages = json::array();
auto add_message = [&](const json & msg) { auto add_message = [&](const json & msg) {
if (requires_typed_content_ && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) { if (caps_.requires_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
actual_messages.push_back({ actual_messages.push_back({
{"role", msg.at("role")}, {"role", msg.at("role")},
{"content", {{ {"content", {{
@ -194,7 +227,7 @@ class chat_template {
pending_system.clear(); pending_system.clear();
} }
}; };
auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !supports_tools_; auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !caps_.supports_tools;
for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) { for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) {
auto message = message_; auto message = message_;
@ -204,7 +237,7 @@ class chat_template {
std::string role = message.at("role"); std::string role = message.at("role");
if (message.contains("tool_calls")) { if (message.contains("tool_calls")) {
if (requires_object_arguments_ || !supports_tool_calls_) { if (caps_.requires_object_arguments || !caps_.supports_tool_calls) {
for (auto & tool_call : message.at("tool_calls")) { for (auto & tool_call : message.at("tool_calls")) {
if (tool_call["type"] == "function") { if (tool_call["type"] == "function") {
auto & function = tool_call.at("function"); auto & function = tool_call.at("function");
@ -219,7 +252,7 @@ class chat_template {
} }
} }
} }
if (!supports_tool_calls_) { if (!caps_.supports_tool_calls) {
auto content = message.at("content"); auto content = message.at("content");
auto tool_calls = json::array(); auto tool_calls = json::array();
for (const auto & tool_call : message.at("tool_calls")) { for (const auto & tool_call : message.at("tool_calls")) {
@ -246,7 +279,7 @@ class chat_template {
message.erase("tool_calls"); message.erase("tool_calls");
} }
} }
if (!supports_tools_ && role == "tool") { if (!caps_.supports_tool_responses && role == "tool") {
message["role"] = "user"; message["role"] = "user";
auto obj = json { auto obj = json {
{"tool_response", { {"tool_response", {
@ -261,7 +294,7 @@ class chat_template {
message.erase("name"); message.erase("name");
} }
if (!message["content"].is_null() && !supports_system_role_) { if (!message["content"].is_null() && !caps_.supports_system_role) {
std::string content = message.at("content"); std::string content = message.at("content");
if (role == "system") { if (role == "system") {
if (!pending_system.empty()) pending_system += "\n"; if (!pending_system.empty()) pending_system += "\n";
@ -280,7 +313,7 @@ class chat_template {
} }
add_message(message); add_message(message);
} }
if (!supports_system_role_) { if (!caps_.supports_system_role) {
flush_sys(); flush_sys();
} }
} else { } else {
@ -295,7 +328,7 @@ class chat_template {
})); }));
if (!tools.is_null()) { if (!tools.is_null()) {
auto tools_val = minja::Value(actual_tools); auto tools_val = minja::Value(tools);
context->set("tools", tools_val); context->set("tools", tools_val);
} }
if (!extra_context.is_null()) { if (!extra_context.is_null()) {
@ -305,7 +338,10 @@ class chat_template {
} }
} }
return template_root_->render(context); auto ret = template_root_->render(context);
// fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str());
// fprintf(stderr, "apply: %s\n\n", ret.c_str());
return ret;
} }
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) { static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {