minja: enhance backfill for templates without a tools description (use the example tool call delta)
This commit is contained in:
parent
4d0598e144
commit
d3b60b8ad8
1 changed file with 49 additions and 1 deletion
|
@ -41,6 +41,7 @@ class chat_template {
|
||||||
std::string bos_token_;
|
std::string bos_token_;
|
||||||
std::string eos_token_;
|
std::string eos_token_;
|
||||||
std::shared_ptr<minja::TemplateNode> template_root_;
|
std::shared_ptr<minja::TemplateNode> template_root_;
|
||||||
|
std::string tool_call_example_;
|
||||||
|
|
||||||
std::string try_raw_render(
|
std::string try_raw_render(
|
||||||
const nlohmann::ordered_json & messages,
|
const nlohmann::ordered_json & messages,
|
||||||
|
@ -176,6 +177,43 @@ class chat_template {
|
||||||
caps_.supports_tool_responses = contains(out, "Some response!");
|
caps_.supports_tool_responses = contains(out, "Some response!");
|
||||||
caps_.supports_tool_call_id = contains(out, "call_911_");
|
caps_.supports_tool_call_id = contains(out, "call_911_");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!caps_.supports_tools) {
|
||||||
|
const json user_msg {
|
||||||
|
{"role", "user"},
|
||||||
|
{"content", "Hey"},
|
||||||
|
};
|
||||||
|
const json tool_call_msg {
|
||||||
|
{"role", "assistant"},
|
||||||
|
{"content", nullptr},
|
||||||
|
{"tool_calls", json::array({
|
||||||
|
{
|
||||||
|
// TODO: detect if requires numerical id or fixed length == 6 like Nemo
|
||||||
|
{"id", "call_1___"},
|
||||||
|
{"type", "function"},
|
||||||
|
{"function", {
|
||||||
|
{"name", "tool_name"},
|
||||||
|
{"arguments", (json {
|
||||||
|
{"arg1", "some_value"},
|
||||||
|
}).dump()},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
})},
|
||||||
|
};
|
||||||
|
const json tools;
|
||||||
|
auto prefix = apply(json::array({user_msg}), tools, /* add_generation_prompt= */ true);
|
||||||
|
auto full = apply(json::array({user_msg, tool_call_msg}), tools, /* add_generation_prompt= */ false);
|
||||||
|
if (full.find(prefix) != 0) {
|
||||||
|
if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
|
||||||
|
prefix = prefix.substr(0, prefix.size() - eos_token_.size());
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("prefix not found at start of full: " + prefix + " vs " + full);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
|
}
|
||||||
|
tool_call_example_ = full.substr(prefix.size());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string & source() const { return source_; }
|
const std::string & source() const { return source_; }
|
||||||
|
@ -229,7 +267,17 @@ class chat_template {
|
||||||
};
|
};
|
||||||
auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !caps_.supports_tools;
|
auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !caps_.supports_tools;
|
||||||
|
|
||||||
for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) {
|
json adjusted_messages;
|
||||||
|
if (needs_tools_in_system) {
|
||||||
|
adjusted_messages = add_system(messages,
|
||||||
|
"\n\n"
|
||||||
|
"You can call any of the following tools to satisfy the user's requests: " + tools.dump(2) + "\n\n"
|
||||||
|
"Example tool call syntax:\n\n" + tool_call_example_ + "\n\n");
|
||||||
|
} else {
|
||||||
|
adjusted_messages = messages;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto & message_ : adjusted_messages) {
|
||||||
auto message = message_;
|
auto message = message_;
|
||||||
if (!message.contains("role") || !message.contains("content")) {
|
if (!message.contains("role") || !message.contains("content")) {
|
||||||
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
|
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue