tool-call: let the tool call handler expand chat template, moving builtin_tools down as extra_context

parent 0c85bc7a8f
commit d983516f40

7 changed files with 64 additions and 10 deletions
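The change in a nutshell: llama_chat_template::apply gains an optional extra_context JSON object whose entries are injected into the Jinja context, so the tool-call handler (not the template class) decides which extra variables a template sees. A minimal sketch of the new call shape, using only names from this diff (`json` is the nlohmann::ordered_json alias used in these files; `tmpl`, `messages` and `tools` are assumed to be in scope):

    // Caller-side view of the new overload (sketch, not part of the diff).
    json builtin_tools = json::array({"wolfram_alpha", "brave_search"});
    std::string prompt = tmpl.apply(
        messages, tools, /* add_generation_prompt= */ true,
        json {{"builtin_tools", builtin_tools}}); // forwarded into the Jinja context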
@@ -78,7 +78,8 @@ llama_chat_template llama_chat_template::from_model(
 std::string llama_chat_template::apply(
     const json & messages,
     const json & tools,
-    bool add_generation_prompt) const
+    bool add_generation_prompt,
+    const json & extra_context) const
 {
     auto actual_messages = messages;
@@ -141,8 +142,12 @@ std::string llama_chat_template::apply(
     if (!tools.is_null()) {
         auto tools_val = minja::Value(tools);
         context->set("tools", tools_val);
-        auto builtin_tools = minja::Value(json {"wolfram_alpha", "brave_search"});
-        context->set("builtin_tools", builtin_tools);
     }
+    if (!extra_context.is_null()) {
+        for (auto & kv : extra_context.items()) {
+            minja::Value val(kv.value());
+            context->set(kv.key(), val);
+        }
+    }

     return _template_root->render(context);
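Every key/value pair of extra_context becomes a template global, so the mechanism is not limited to builtin_tools. A hedged sketch of that generality (the todays_date key is hypothetical and not part of this commit):

    json extra {
        {"builtin_tools", json::array({"wolfram_alpha", "brave_search"})},
        {"todays_date", "26 Sep 2024"} // hypothetical key, shown only to illustrate
    };
    auto prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true, extra);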
@@ -48,5 +48,6 @@ class llama_chat_template {
     std::string apply(
         const nlohmann::ordered_json & messages,
         const nlohmann::ordered_json & tools,
-        bool add_generation_prompt) const;
+        bool add_generation_prompt,
+        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const;
 };
@@ -218,6 +218,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
     const llama_chat_template & tmpl,
     bool allow_content,
     bool parallel_tool_calls,
+    const nlohmann::ordered_json & messages,
     const nlohmann::ordered_json & tools)
 {
     llama_tool_call_handler handler;
@@ -255,6 +256,9 @@ llama_tool_call_handler llama_tool_call_handler_init(
                 builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | "));
             });
             handler.additional_stop_words.push_back("<|eom_id|>");
+            handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true, {
+                {"builtin_tools", builtin_tools},
+            });
             break;
         }
         case llama_tool_call_style::FunctionaryV3Llama3: {
@@ -284,6 +288,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
                     builder.add_rule("root", first_rule);
                 }
             });
+            handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
             // handler.parser = parse_functionary_3_2_tool_calls;
             break;
         }
@@ -313,6 +318,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
                     handler.grammar_trigger_words.push_back("<function=");
                 }
             });
+            handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
             // handler.parser = parse_functionary_3_2_tool_calls;
             break;
         }
@@ -342,6 +348,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
                     handler.grammar_trigger_words.push_back("<tool_call>");
                 }
             });
+            handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
             break;
         }
         default:
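With each style now filling in handler.prompt itself, callers only consume the handler's outputs; the server change below does exactly this. A minimal consumer sketch (run_completion is a hypothetical stand-in for whatever executes the request):

    auto handler = llama_tool_call_handler_init(
        tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools);
    // The handler owns the fully expanded prompt plus the grammar that constrains output.
    run_completion(handler.prompt, handler.grammar); // run_completion is hypothetical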
@@ -18,6 +18,7 @@ struct llama_tool_calls {
 };

 struct llama_tool_call_handler {
+    std::string prompt;
     std::string grammar;
     std::vector<std::string> grammar_trigger_words;
     std::vector<std::string> additional_stop_words;
@@ -29,4 +30,5 @@ llama_tool_call_handler llama_tool_call_handler_init(
     const llama_chat_template & tmpl,
     bool allow_content,
     bool parallel_tool_calls,
+    const nlohmann::ordered_json & messages,
     const nlohmann::ordered_json & tools);
@@ -372,7 +372,8 @@ static json oaicompat_completion_params_parse(
             llama_params["parse_tool_calls"] = true;
             llama_params["parallel_tool_calls"] = parallel_tool_calls;

-            auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, tools);
+            auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools);
+            llama_params["prompt"] = handler.prompt;

             for (const auto & stop : handler.additional_stop_words) {
                 llama_params["stop"].push_back(stop);
@@ -390,8 +391,9 @@ static json oaicompat_completion_params_parse(
             }
             llama_params["grammar"] = handler.grammar;
-        }
-        llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
+        } else {
+            llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
+        }
     } else {
         llama_params["prompt"] = format_chat(model, tmpl.chat_template(), body.at("messages"));
     }
@@ -104,7 +104,10 @@ static void test_jinja_templates() {
             actual = tmpl.apply(
                 ctx.at("messages"),
                 ctx.contains("tools") ? ctx.at("tools") : json(),
-                ctx.at("add_generation_prompt"));
+                ctx.at("add_generation_prompt"),
+                ctx.contains("tools") ? json {
+                    {"builtin_tools", {"wolfram_alpha", "brave_search"}}
+                } : json());
         } catch (const std::runtime_error & e) {
             actual = "ERROR: " + std::string(e.what());
         }
@@ -16,6 +16,20 @@ static void assert_equals(const std::string & expected, const std::string & actual) {
     }
 }

+static std::string read_file(const std::string &path) {
+    std::ifstream fs(path, std::ios_base::binary);
+    if (!fs.is_open()) {
+        throw std::runtime_error("Failed to open file: " + path);
+    }
+    fs.seekg(0, std::ios_base::end);
+    auto size = fs.tellg();
+    fs.seekg(0);
+    std::string out;
+    out.resize(static_cast<size_t>(size));
+    fs.read(&out[0], static_cast<std::streamsize>(size));
+    return out;
+}
+
 /*
   cmake -B build -DLLAMA_CURL=1 -DCMAKE_BUILD_TYPE=Release && cmake --build build -t test-tool-call -j && ./build/bin/test-tool-call
 */
@@ -53,6 +67,23 @@ int main() {
           "required": ["arg1"]
         }
       }
-    }
+    },
+    {
+      "type": "function",
+      "function": {
+        "name": "ipython",
+        "description": "a python interpreter",
+        "parameters": {
+          "type": "object",
+          "properties": {
+            "code": {
+              "type": "string",
+              "description": "The code."
+            }
+          },
+          "required": ["code"]
+        }
+      }
+    }
   ])");
   json request = {
@@ -83,12 +114,14 @@ int main() {
         }}
     }});
     test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools,
-        ">>>test\n{ } \n ",
+        ">>>special_function\n{\"arg1\": 1}\n ",
         "",
         json {{
             {"function", {
-                {"name", "test"},
-                {"arguments", "{}"}
+                {"name", "special_function"},
+                {"arguments", (json {
+                    {"arg1", 1}
+                }).dump()}
             }}
         }});
@@ -158,5 +191,6 @@ int main() {
         "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
-        "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}");
+        "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
+        json::array());

     std::cout << "[tool-call] All tests passed!" << std::endl;
     return 0;
 }