minja: enhance backfill of templates w/o tools description (use example tool call delta!)
This commit is contained in:
parent
4d0598e144
commit
d3b60b8ad8
1 changed files with 49 additions and 1 deletions
|
@ -41,6 +41,7 @@ class chat_template {
|
|||
std::string bos_token_;
|
||||
std::string eos_token_;
|
||||
std::shared_ptr<minja::TemplateNode> template_root_;
|
||||
std::string tool_call_example_;
|
||||
|
||||
std::string try_raw_render(
|
||||
const nlohmann::ordered_json & messages,
|
||||
|
@ -176,6 +177,43 @@ class chat_template {
|
|||
caps_.supports_tool_responses = contains(out, "Some response!");
|
||||
caps_.supports_tool_call_id = contains(out, "call_911_");
|
||||
}
|
||||
|
||||
if (!caps_.supports_tools) {
|
||||
const json user_msg {
|
||||
{"role", "user"},
|
||||
{"content", "Hey"},
|
||||
};
|
||||
const json tool_call_msg {
|
||||
{"role", "assistant"},
|
||||
{"content", nullptr},
|
||||
{"tool_calls", json::array({
|
||||
{
|
||||
// TODO: detect if requires numerical id or fixed length == 6 like Nemo
|
||||
{"id", "call_1___"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "tool_name"},
|
||||
{"arguments", (json {
|
||||
{"arg1", "some_value"},
|
||||
}).dump()},
|
||||
}},
|
||||
},
|
||||
})},
|
||||
};
|
||||
const json tools;
|
||||
auto prefix = apply(json::array({user_msg}), tools, /* add_generation_prompt= */ true);
|
||||
auto full = apply(json::array({user_msg, tool_call_msg}), tools, /* add_generation_prompt= */ false);
|
||||
if (full.find(prefix) != 0) {
|
||||
if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
|
||||
prefix = prefix.substr(0, prefix.size() - eos_token_.size());
|
||||
} else {
|
||||
throw std::runtime_error("prefix not found at start of full: " + prefix + " vs " + full);
|
||||
}
|
||||
} else {
|
||||
|
||||
}
|
||||
tool_call_example_ = full.substr(prefix.size());
|
||||
}
|
||||
}
|
||||
|
||||
const std::string & source() const { return source_; }
|
||||
|
@ -229,7 +267,17 @@ class chat_template {
|
|||
};
|
||||
auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !caps_.supports_tools;
|
||||
|
||||
for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) {
|
||||
json adjusted_messages;
|
||||
if (needs_tools_in_system) {
|
||||
adjusted_messages = add_system(messages,
|
||||
"\n\n"
|
||||
"You can call any of the following tools to satisfy the user's requests: " + tools.dump(2) + "\n\n"
|
||||
"Example tool call syntax:\n\n" + tool_call_example_ + "\n\n");
|
||||
} else {
|
||||
adjusted_messages = messages;
|
||||
}
|
||||
|
||||
for (const auto & message_ : adjusted_messages) {
|
||||
auto message = message_;
|
||||
if (!message.contains("role") || !message.contains("content")) {
|
||||
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue