tool-call: let the tool-call handler expand the chat template, passing builtin_tools down via extra_context

Author: ochafik
Date:   2024-09-28 17:46:36 +01:00
Parent: 0c85bc7a8f
Commit: d983516f40
7 changed files with 64 additions and 10 deletions
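In plain terms: llama_tool_call_handler_init() now receives the chat messages and renders the fully expanded prompt itself, and the builtin_tools list that llama_chat_template::apply() used to hard-code is instead passed down through a new extra_context parameter. A minimal caller-side sketch of the resulting flow, assuming only the declarations visible in the hunks below; the helper name and the commented include path are illustrative, not part of this commit:

#include <nlohmann/json.hpp>
// #include "tool-call.h"   // assumed header for llama_tool_call_handler_init(); the path is not shown in this commit

using json = nlohmann::ordered_json;

// Hypothetical helper illustrating the new flow: the handler renders the prompt
// internally (via tmpl.apply(messages, tools, ..., extra_context)), so callers
// read handler.prompt instead of expanding the template themselves.
static std::string build_tool_call_prompt(const llama_chat_template & tmpl,
                                          const json & messages,
                                          const json & tools) {
    auto handler = llama_tool_call_handler_init(
        tmpl,
        /* allow_content= */ true,
        /* parallel_tool_calls= */ false,
        messages,
        tools);
    return handler.prompt;
}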

View file

@@ -78,7 +78,8 @@ llama_chat_template llama_chat_template::from_model(
std::string llama_chat_template::apply(
const json & messages,
const json & tools,
bool add_generation_prompt) const
bool add_generation_prompt,
const json & extra_context) const
{
auto actual_messages = messages;
@@ -141,8 +142,12 @@ std::string llama_chat_template::apply(
if (!tools.is_null()) {
auto tools_val = minja::Value(tools);
context->set("tools", tools_val);
auto builtin_tools = minja::Value(json {"wolfram_alpha", "brave_search"});
context->set("builtin_tools", builtin_tools);
}
if (!extra_context.is_null()) {
for (auto & kv : extra_context.items()) {
minja::Value val(kv.value());
context->set(kv.key(), val);
}
}
return _template_root->render(context);
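The loop above is the whole mechanism behind the commit title: each key/value pair in extra_context is copied into the Jinja rendering context via context->set(), so builtin_tools is no longer injected unconditionally but only when a caller asks for it. A hedged usage sketch, relying only on the apply() signature from this hunk; json is nlohmann::ordered_json, and tmpl, messages, tools stand for values the caller already has:

// Sketch only: the values mirror what the old hard-coded branch used to set.
json extra_context = {
    {"builtin_tools", json::array({"wolfram_alpha", "brave_search"})},
};
// Templates that reference builtin_tools render exactly as before this commit;
// templates that never mention it are unaffected.
std::string prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true, extra_context);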

View file

@@ -48,5 +48,6 @@ class llama_chat_template {
std::string apply(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt) const;
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const;
};

View file

@@ -218,6 +218,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
const llama_chat_template & tmpl,
bool allow_content,
bool parallel_tool_calls,
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools)
{
llama_tool_call_handler handler;
@@ -255,6 +256,9 @@ llama_tool_call_handler llama_tool_call_handler_init(
builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | "));
});
handler.additional_stop_words.push_back("<|eom_id|>");
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true, {
{"builtin_tools", builtin_tools},
});
break;
}
case llama_tool_call_style::FunctionaryV3Llama3: {
@@ -284,6 +288,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
builder.add_rule("root", first_rule);
}
});
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
// handler.parser = parse_functionary_3_2_tool_calls;
break;
}
@@ -313,6 +318,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
handler.grammar_trigger_words.push_back("<function=");
}
});
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
// handler.parser = parse_functionary_3_2_tool_calls;
break;
}
@@ -342,6 +348,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
handler.grammar_trigger_words.push_back("<tool_call>");
}
});
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
break;
}
default:

View file

@@ -18,6 +18,7 @@ struct llama_tool_calls {
};
struct llama_tool_call_handler {
std::string prompt;
std::string grammar;
std::vector<std::string> grammar_trigger_words;
std::vector<std::string> additional_stop_words;
@@ -29,4 +30,5 @@ llama_tool_call_handler llama_tool_call_handler_init(
const llama_chat_template & tmpl,
bool allow_content,
bool parallel_tool_calls,
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools);

View file

@@ -372,7 +372,8 @@ static json oaicompat_completion_params_parse(
llama_params["parse_tool_calls"] = true;
llama_params["parallel_tool_calls"] = parallel_tool_calls;
auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, tools);
auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools);
llama_params["prompt"] = handler.prompt;
for (const auto & stop : handler.additional_stop_words) {
llama_params["stop"].push_back(stop);
@@ -390,8 +391,9 @@ static json oaicompat_completion_params_parse(
}
llama_params["grammar"] = handler.grammar;
}
} else {
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
}
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
} else {
llama_params["prompt"] = format_chat(model, tmpl.chat_template(), body.at("messages"));
}

View file

@@ -104,7 +104,10 @@ static void test_jinja_templates() {
actual = tmpl.apply(
ctx.at("messages"),
ctx.contains("tools") ? ctx.at("tools") : json(),
ctx.at("add_generation_prompt"));
ctx.at("add_generation_prompt"),
ctx.contains("tools") ? json {
{"builtin_tools", {"wolfram_alpha", "brave_search"}}
} : json());
} catch (const std::runtime_error & e) {
actual = "ERROR: " + std::string(e.what());
}

View file

@@ -16,6 +16,20 @@ static void assert_equals(const std::string & expected, const std::string & actu
}
}
static std::string read_file(const std::string &path) {
std::ifstream fs(path, std::ios_base::binary);
if (!fs.is_open()) {
throw std::runtime_error("Failed to open file: " + path);
}
fs.seekg(0, std::ios_base::end);
auto size = fs.tellg();
fs.seekg(0);
std::string out;
out.resize(static_cast<size_t>(size));
fs.read(&out[0], static_cast<std::streamsize>(size));
return out;
}
/*
cmake -B build -DLLAMA_CURL=1 -DCMAKE_BUILD_TYPE=Release && cmake --build build -t test-tool-call -j && ./build/bin/test-tool-call
*/
@@ -53,6 +67,23 @@ int main() {
"required": ["arg1"]
}
}
},
{
"type": "function",
"function": {
"name": "ipython",
"description": "a python interpreter",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The code."
}
},
"required": ["code"]
}
}
}
])");
json request = {
@@ -83,12 +114,14 @@ int main() {
}}
}});
test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools,
">>>test\n{ } \n ",
">>>special_function\n{\"arg1\": 1}\n ",
"",
json {{
{"function", {
{"name", "test"},
{"arguments", "{}"}
{"name", "special_function"},
{"arguments", (json {
{"arg1", 1}
}).dump()}
}}
}});
@@ -158,5 +191,6 @@ int main() {
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array());
std::cout << "[tool-call] All tests passed!" << std::endl;
return 0;
}