minja: enhance backfill of templates w/o tools description (use example tool call delta!)

This commit is contained in:
ochafik 2025-02-03 01:03:04 +00:00
parent 4d0598e144
commit d3b60b8ad8

View file

@ -41,6 +41,7 @@ class chat_template {
// Presumably the begin-/end-of-sequence token strings for this template (confirm against the constructor).
// eos_token_ is also compared against the tail of a rendered prompt when the tool call example is derived.
std::string bos_token_;
std::string eos_token_;
// Shared handle to the template's root node; NOTE(review): presumably the parsed form of source_ — confirm.
std::shared_ptr<minja::TemplateNode> template_root_;
// Rendered text that an assistant tool call adds after the prompt, captured when the template
// lacks native tool support and embedded in the system message as example tool call syntax.
std::string tool_call_example_;
std::string try_raw_render(
const nlohmann::ordered_json & messages,
@ -176,6 +177,43 @@ class chat_template {
caps_.supports_tool_responses = contains(out, "Some response!");
caps_.supports_tool_call_id = contains(out, "call_911_");
}
// The template has no native tool support: derive an example of its tool call syntax.
// A dummy conversation is rendered twice (prompt only, and prompt + assistant tool call);
// whatever the second render adds after the first is stored in tool_call_example_.
// Throws std::runtime_error when the prompt render cannot be matched to the start of the full render.
if (!caps_.supports_tools) {
    const json user_msg {
        {"role", "user"},
        {"content", "Hey"},
    };
    const json tool_call_msg {
        {"role", "assistant"},
        {"content", nullptr},
        {"tool_calls", json::array({
            {
                // TODO: detect if requires numerical id or fixed length == 6 like Nemo
                {"id", "call_1___"},
                {"type", "function"},
                {"function", {
                    {"name", "tool_name"},
                    {"arguments", (json {
                        {"arg1", "some_value"},
                    }).dump()},
                }},
            },
        })},
    };
    // Default-constructed on purpose: no tool definitions are supplied for these two renders.
    const json tools;
    auto prefix = apply(json::array({user_msg}), tools, /* add_generation_prompt= */ true);
    const auto full = apply(json::array({user_msg, tool_call_msg}), tools, /* add_generation_prompt= */ false);

    // Anchored comparison: unlike find(), this never scans past the start of `str`.
    const auto starts_with = [](const std::string & str, const std::string & head) {
        return str.compare(0, head.size(), head) == 0;
    };

    if (!starts_with(full, prefix)) {
        // The prompt render may end with an EOS token that the full render does not repeat at
        // that position. The size checks keep the unsigned subtraction from wrapping around and
        // stop an empty EOS token from counting as a match.
        const bool ends_with_eos = !eos_token_.empty()
            && prefix.size() >= eos_token_.size()
            && prefix.compare(prefix.size() - eos_token_.size(), eos_token_.size(), eos_token_) == 0;
        if (ends_with_eos) {
            prefix.erase(prefix.size() - eos_token_.size());
        }
        // Re-validate after trimming: slicing `full` at a non-matching offset would store an
        // arbitrary fragment (or make substr throw std::out_of_range) instead of a real example.
        if (!ends_with_eos || !starts_with(full, prefix)) {
            throw std::runtime_error("prefix not found at start of full: " + prefix + " vs " + full);
        }
    }
    tool_call_example_ = full.substr(prefix.size());
}
}
// Read-only accessor: returns the stored source_ string by const reference.
const std::string & source() const
{
    return source_;
}
@ -229,7 +267,17 @@ class chat_template {
};
auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !caps_.supports_tools;
for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) {
json adjusted_messages;
if (needs_tools_in_system) {
adjusted_messages = add_system(messages,
"\n\n"
"You can call any of the following tools to satisfy the user's requests: " + tools.dump(2) + "\n\n"
"Example tool call syntax:\n\n" + tool_call_example_ + "\n\n");
} else {
adjusted_messages = messages;
}
for (const auto & message_ : adjusted_messages) {
auto message = message_;
if (!message.contains("role") || !message.contains("content")) {
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());