tool-call: let the tool call handler expand the chat template, moving builtin_tools down as extra_context
This commit is contained in:
parent
0c85bc7a8f
commit
d983516f40
7 changed files with 64 additions and 10 deletions
|
@ -78,7 +78,8 @@ llama_chat_template llama_chat_template::from_model(
|
||||||
std::string llama_chat_template::apply(
|
std::string llama_chat_template::apply(
|
||||||
const json & messages,
|
const json & messages,
|
||||||
const json & tools,
|
const json & tools,
|
||||||
bool add_generation_prompt) const
|
bool add_generation_prompt,
|
||||||
|
const json & extra_context) const
|
||||||
{
|
{
|
||||||
auto actual_messages = messages;
|
auto actual_messages = messages;
|
||||||
|
|
||||||
|
@ -141,8 +142,12 @@ std::string llama_chat_template::apply(
|
||||||
if (!tools.is_null()) {
|
if (!tools.is_null()) {
|
||||||
auto tools_val = minja::Value(tools);
|
auto tools_val = minja::Value(tools);
|
||||||
context->set("tools", tools_val);
|
context->set("tools", tools_val);
|
||||||
auto builtin_tools = minja::Value(json {"wolfram_alpha", "brave_search"});
|
}
|
||||||
context->set("builtin_tools", builtin_tools);
|
if (!extra_context.is_null()) {
|
||||||
|
for (auto & kv : extra_context.items()) {
|
||||||
|
minja::Value val(kv.value());
|
||||||
|
context->set(kv.key(), val);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return _template_root->render(context);
|
return _template_root->render(context);
|
||||||
|
|
|
@ -48,5 +48,6 @@ class llama_chat_template {
|
||||||
std::string apply(
|
std::string apply(
|
||||||
const nlohmann::ordered_json & messages,
|
const nlohmann::ordered_json & messages,
|
||||||
const nlohmann::ordered_json & tools,
|
const nlohmann::ordered_json & tools,
|
||||||
bool add_generation_prompt) const;
|
bool add_generation_prompt,
|
||||||
|
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const;
|
||||||
};
|
};
|
||||||
|
|
|
@ -218,6 +218,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
||||||
const llama_chat_template & tmpl,
|
const llama_chat_template & tmpl,
|
||||||
bool allow_content,
|
bool allow_content,
|
||||||
bool parallel_tool_calls,
|
bool parallel_tool_calls,
|
||||||
|
const nlohmann::ordered_json & messages,
|
||||||
const nlohmann::ordered_json & tools)
|
const nlohmann::ordered_json & tools)
|
||||||
{
|
{
|
||||||
llama_tool_call_handler handler;
|
llama_tool_call_handler handler;
|
||||||
|
@ -255,6 +256,9 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
||||||
builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | "));
|
builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | "));
|
||||||
});
|
});
|
||||||
handler.additional_stop_words.push_back("<|eom_id|>");
|
handler.additional_stop_words.push_back("<|eom_id|>");
|
||||||
|
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true, {
|
||||||
|
{"builtin_tools", builtin_tools},
|
||||||
|
});
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case llama_tool_call_style::FunctionaryV3Llama3: {
|
case llama_tool_call_style::FunctionaryV3Llama3: {
|
||||||
|
@ -284,6 +288,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
||||||
builder.add_rule("root", first_rule);
|
builder.add_rule("root", first_rule);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
|
||||||
// handler.parser = parse_functionary_3_2_tool_calls;
|
// handler.parser = parse_functionary_3_2_tool_calls;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -313,6 +318,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
||||||
handler.grammar_trigger_words.push_back("<function=");
|
handler.grammar_trigger_words.push_back("<function=");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
|
||||||
// handler.parser = parse_functionary_3_2_tool_calls;
|
// handler.parser = parse_functionary_3_2_tool_calls;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -342,6 +348,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
||||||
handler.grammar_trigger_words.push_back("<tool_call>");
|
handler.grammar_trigger_words.push_back("<tool_call>");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -18,6 +18,7 @@ struct llama_tool_calls {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_tool_call_handler {
|
struct llama_tool_call_handler {
|
||||||
|
std::string prompt;
|
||||||
std::string grammar;
|
std::string grammar;
|
||||||
std::vector<std::string> grammar_trigger_words;
|
std::vector<std::string> grammar_trigger_words;
|
||||||
std::vector<std::string> additional_stop_words;
|
std::vector<std::string> additional_stop_words;
|
||||||
|
@ -29,4 +30,5 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
||||||
const llama_chat_template & tmpl,
|
const llama_chat_template & tmpl,
|
||||||
bool allow_content,
|
bool allow_content,
|
||||||
bool parallel_tool_calls,
|
bool parallel_tool_calls,
|
||||||
|
const nlohmann::ordered_json & messages,
|
||||||
const nlohmann::ordered_json & tools);
|
const nlohmann::ordered_json & tools);
|
||||||
|
|
|
@ -372,7 +372,8 @@ static json oaicompat_completion_params_parse(
|
||||||
llama_params["parse_tool_calls"] = true;
|
llama_params["parse_tool_calls"] = true;
|
||||||
llama_params["parallel_tool_calls"] = parallel_tool_calls;
|
llama_params["parallel_tool_calls"] = parallel_tool_calls;
|
||||||
|
|
||||||
auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, tools);
|
auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools);
|
||||||
|
llama_params["prompt"] = handler.prompt;
|
||||||
|
|
||||||
for (const auto & stop : handler.additional_stop_words) {
|
for (const auto & stop : handler.additional_stop_words) {
|
||||||
llama_params["stop"].push_back(stop);
|
llama_params["stop"].push_back(stop);
|
||||||
|
@ -390,8 +391,9 @@ static json oaicompat_completion_params_parse(
|
||||||
}
|
}
|
||||||
llama_params["grammar"] = handler.grammar;
|
llama_params["grammar"] = handler.grammar;
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
|
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
llama_params["prompt"] = format_chat(model, tmpl.chat_template(), body.at("messages"));
|
llama_params["prompt"] = format_chat(model, tmpl.chat_template(), body.at("messages"));
|
||||||
}
|
}
|
||||||
|
|
|
@ -104,7 +104,10 @@ static void test_jinja_templates() {
|
||||||
actual = tmpl.apply(
|
actual = tmpl.apply(
|
||||||
ctx.at("messages"),
|
ctx.at("messages"),
|
||||||
ctx.contains("tools") ? ctx.at("tools") : json(),
|
ctx.contains("tools") ? ctx.at("tools") : json(),
|
||||||
ctx.at("add_generation_prompt"));
|
ctx.at("add_generation_prompt"),
|
||||||
|
ctx.contains("tools") ? json {
|
||||||
|
{"builtin_tools", {"wolfram_alpha", "brave_search"}}
|
||||||
|
} : json());
|
||||||
} catch (const std::runtime_error & e) {
|
} catch (const std::runtime_error & e) {
|
||||||
actual = "ERROR: " + std::string(e.what());
|
actual = "ERROR: " + std::string(e.what());
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,20 @@ static void assert_equals(const std::string & expected, const std::string & actu
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reads the whole file at `path` in binary mode and returns its bytes as a string.
//
// Throws std::runtime_error when the file cannot be opened, when its size
// cannot be determined, or when fewer bytes than the reported size can be read.
static std::string read_file(const std::string &path) {
    std::ifstream fs(path, std::ios_base::binary);
    if (!fs.is_open()) {
        throw std::runtime_error("Failed to open file: " + path);
    }

    // Measure the file by seeking to its end.
    fs.seekg(0, std::ios_base::end);
    const std::streamoff size = fs.tellg();
    // tellg() signals failure with -1; casting that to size_t would request a
    // gigantic allocation instead of reporting the real problem.
    if (size < 0) {
        throw std::runtime_error("Failed to determine size of file: " + path);
    }
    fs.seekg(0);

    std::string out(static_cast<size_t>(size), '\0');
    // A short or failed read would otherwise leave zero padding in the result.
    if (!out.empty() && !fs.read(&out[0], static_cast<std::streamsize>(size))) {
        throw std::runtime_error("Failed to read file: " + path);
    }
    return out;
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
cmake -B build -DLLAMA_CURL=1 -DCMAKE_BUILD_TYPE=Release && cmake --build build -t test-tool-call -j && ./build/bin/test-tool-call
|
cmake -B build -DLLAMA_CURL=1 -DCMAKE_BUILD_TYPE=Release && cmake --build build -t test-tool-call -j && ./build/bin/test-tool-call
|
||||||
*/
|
*/
|
||||||
|
@ -53,6 +67,23 @@ int main() {
|
||||||
"required": ["arg1"]
|
"required": ["arg1"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "ipython",
|
||||||
|
"description": "a python interpreter",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"code": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The code."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["code"]
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
])");
|
])");
|
||||||
json request = {
|
json request = {
|
||||||
|
@ -83,12 +114,14 @@ int main() {
|
||||||
}}
|
}}
|
||||||
}});
|
}});
|
||||||
test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools,
|
test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools,
|
||||||
">>>test\n{ } \n ",
|
">>>special_function\n{\"arg1\": 1}\n ",
|
||||||
"",
|
"",
|
||||||
json {{
|
json {{
|
||||||
{"function", {
|
{"function", {
|
||||||
{"name", "test"},
|
{"name", "special_function"},
|
||||||
{"arguments", "{}"}
|
{"arguments", (json {
|
||||||
|
{"arg1", 1}
|
||||||
|
}).dump()}
|
||||||
}}
|
}}
|
||||||
}});
|
}});
|
||||||
|
|
||||||
|
@ -158,5 +191,6 @@ int main() {
|
||||||
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||||
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array());
|
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array());
|
||||||
|
|
||||||
|
std::cout << "[tool-call] All tests passed!" << std::endl;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue