tool-call
: behaviour-based detection of template features
This commit is contained in:
parent
e8d9d711f6
commit
c395d4804f
1 changed files with 30 additions and 7 deletions
|
@ -32,22 +32,45 @@ class chat_template {
|
||||||
std::string _eos_token;
|
std::string _eos_token;
|
||||||
std::shared_ptr<minja::TemplateNode> _template_root;
|
std::shared_ptr<minja::TemplateNode> _template_root;
|
||||||
|
|
||||||
|
bool renders_needles(
|
||||||
|
const std::vector<std::string> & needles,
|
||||||
|
const nlohmann::ordered_json & messages,
|
||||||
|
const nlohmann::ordered_json & tools,
|
||||||
|
bool add_generation_prompt,
|
||||||
|
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
|
||||||
|
{
|
||||||
|
try {
|
||||||
|
auto prompt = apply(messages, tools, add_generation_prompt, extra_context);
|
||||||
|
for (const auto & needle : needles) {
|
||||||
|
if (prompt.find(needle) == std::string::npos) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
} catch (const std::exception & e) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
|
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
|
||||||
: _source(source), _bos_token(bos_token), _eos_token(eos_token)
|
: _source(source), _bos_token(bos_token), _eos_token(eos_token)
|
||||||
{
|
{
|
||||||
_supports_tools = source.find("tools") != std::string::npos;
|
|
||||||
_requires_object_arguments =
|
|
||||||
source.find("tool_call.arguments | items") != std::string::npos
|
|
||||||
|| source.find("tool_call.arguments | tojson") != std::string::npos;
|
|
||||||
_supports_system_role = source.find("System role not supported") == std::string::npos;
|
|
||||||
_supports_parallel_tool_calls = source.find("tool_call_id") != std::string::npos;
|
|
||||||
|
|
||||||
_template_root = minja::Parser::parse(_source, {
|
_template_root = minja::Parser::parse(_source, {
|
||||||
/* .trim_blocks = */ true,
|
/* .trim_blocks = */ true,
|
||||||
/* .lstrip_blocks = */ true,
|
/* .lstrip_blocks = */ true,
|
||||||
/* .keep_trailing_newline = */ false,
|
/* .keep_trailing_newline = */ false,
|
||||||
});
|
});
|
||||||
|
_supports_tools = source.find("tools") != std::string::npos;
|
||||||
|
_requires_object_arguments =
|
||||||
|
source.find("tool_call.arguments | items") != std::string::npos
|
||||||
|
|| source.find("tool_call.arguments | tojson") != std::string::npos;
|
||||||
|
_supports_parallel_tool_calls = source.find("tool_call_id") != std::string::npos;
|
||||||
|
|
||||||
|
_supports_system_role = renders_needles({"<System Needle>"}, {
|
||||||
|
{{"role", "system"}, {"content", "<System Needle>"}},
|
||||||
|
{{"role", "user"}, {"content", "Hey"}}
|
||||||
|
}, {}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string & source() const { return _source; }
|
const std::string & source() const { return _source; }
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue