tool-call: let the tool call handler expand chat template, moving builtin_tools down as extra_context

This commit is contained in:
ochafik 2024-09-28 17:46:36 +01:00
parent 0c85bc7a8f
commit d983516f40
7 changed files with 64 additions and 10 deletions

View file

@@ -78,7 +78,8 @@ llama_chat_template llama_chat_template::from_model(
std::string llama_chat_template::apply( std::string llama_chat_template::apply(
const json & messages, const json & messages,
const json & tools, const json & tools,
bool add_generation_prompt) const bool add_generation_prompt,
const json & extra_context) const
{ {
auto actual_messages = messages; auto actual_messages = messages;
@@ -141,8 +142,12 @@ std::string llama_chat_template::apply(
if (!tools.is_null()) { if (!tools.is_null()) {
auto tools_val = minja::Value(tools); auto tools_val = minja::Value(tools);
context->set("tools", tools_val); context->set("tools", tools_val);
auto builtin_tools = minja::Value(json {"wolfram_alpha", "brave_search"}); }
context->set("builtin_tools", builtin_tools); if (!extra_context.is_null()) {
for (auto & kv : extra_context.items()) {
minja::Value val(kv.value());
context->set(kv.key(), val);
}
} }
return _template_root->render(context); return _template_root->render(context);

View file

@@ -48,5 +48,6 @@ class llama_chat_template {
std::string apply( std::string apply(
const nlohmann::ordered_json & messages, const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools, const nlohmann::ordered_json & tools,
bool add_generation_prompt) const; bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const;
}; };

View file

@@ -218,6 +218,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
const llama_chat_template & tmpl, const llama_chat_template & tmpl,
bool allow_content, bool allow_content,
bool parallel_tool_calls, bool parallel_tool_calls,
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools) const nlohmann::ordered_json & tools)
{ {
llama_tool_call_handler handler; llama_tool_call_handler handler;
@@ -255,6 +256,9 @@ llama_tool_call_handler llama_tool_call_handler_init(
builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | ")); builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | "));
}); });
handler.additional_stop_words.push_back("<|eom_id|>"); handler.additional_stop_words.push_back("<|eom_id|>");
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true, {
{"builtin_tools", builtin_tools},
});
break; break;
} }
case llama_tool_call_style::FunctionaryV3Llama3: { case llama_tool_call_style::FunctionaryV3Llama3: {
@@ -284,6 +288,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
builder.add_rule("root", first_rule); builder.add_rule("root", first_rule);
} }
}); });
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
// handler.parser = parse_functionary_3_2_tool_calls; // handler.parser = parse_functionary_3_2_tool_calls;
break; break;
} }
@@ -313,6 +318,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
handler.grammar_trigger_words.push_back("<function="); handler.grammar_trigger_words.push_back("<function=");
} }
}); });
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
// handler.parser = parse_functionary_3_2_tool_calls; // handler.parser = parse_functionary_3_2_tool_calls;
break; break;
} }
@@ -342,6 +348,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
handler.grammar_trigger_words.push_back("<tool_call>"); handler.grammar_trigger_words.push_back("<tool_call>");
} }
}); });
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
break; break;
} }
default: default:

View file

@@ -18,6 +18,7 @@ struct llama_tool_calls {
}; };
struct llama_tool_call_handler { struct llama_tool_call_handler {
std::string prompt;
std::string grammar; std::string grammar;
std::vector<std::string> grammar_trigger_words; std::vector<std::string> grammar_trigger_words;
std::vector<std::string> additional_stop_words; std::vector<std::string> additional_stop_words;
@@ -29,4 +30,5 @@ llama_tool_call_handler llama_tool_call_handler_init(
const llama_chat_template & tmpl, const llama_chat_template & tmpl,
bool allow_content, bool allow_content,
bool parallel_tool_calls, bool parallel_tool_calls,
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools); const nlohmann::ordered_json & tools);

View file

@@ -372,7 +372,8 @@ static json oaicompat_completion_params_parse(
llama_params["parse_tool_calls"] = true; llama_params["parse_tool_calls"] = true;
llama_params["parallel_tool_calls"] = parallel_tool_calls; llama_params["parallel_tool_calls"] = parallel_tool_calls;
auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, tools); auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools);
llama_params["prompt"] = handler.prompt;
for (const auto & stop : handler.additional_stop_words) { for (const auto & stop : handler.additional_stop_words) {
llama_params["stop"].push_back(stop); llama_params["stop"].push_back(stop);
@@ -390,8 +391,9 @@ static json oaicompat_completion_params_parse(
} }
llama_params["grammar"] = handler.grammar; llama_params["grammar"] = handler.grammar;
} }
} } else {
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true); llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
}
} else { } else {
llama_params["prompt"] = format_chat(model, tmpl.chat_template(), body.at("messages")); llama_params["prompt"] = format_chat(model, tmpl.chat_template(), body.at("messages"));
} }

View file

@@ -104,7 +104,10 @@ static void test_jinja_templates() {
actual = tmpl.apply( actual = tmpl.apply(
ctx.at("messages"), ctx.at("messages"),
ctx.contains("tools") ? ctx.at("tools") : json(), ctx.contains("tools") ? ctx.at("tools") : json(),
ctx.at("add_generation_prompt")); ctx.at("add_generation_prompt"),
ctx.contains("tools") ? json {
{"builtin_tools", {"wolfram_alpha", "brave_search"}}
} : json());
} catch (const std::runtime_error & e) { } catch (const std::runtime_error & e) {
actual = "ERROR: " + std::string(e.what()); actual = "ERROR: " + std::string(e.what());
} }

View file

@@ -16,6 +16,20 @@ static void assert_equals(const std::string & expected, const std::string & actu
} }
} }
static std::string read_file(const std::string &path) {
std::ifstream fs(path, std::ios_base::binary);
if (!fs.is_open()) {
throw std::runtime_error("Failed to open file: " + path);
}
fs.seekg(0, std::ios_base::end);
auto size = fs.tellg();
fs.seekg(0);
std::string out;
out.resize(static_cast<size_t>(size));
fs.read(&out[0], static_cast<std::streamsize>(size));
return out;
}
/* /*
cmake -B build -DLLAMA_CURL=1 -DCMAKE_BUILD_TYPE=Release && cmake --build build -t test-tool-call -j && ./build/bin/test-tool-call cmake -B build -DLLAMA_CURL=1 -DCMAKE_BUILD_TYPE=Release && cmake --build build -t test-tool-call -j && ./build/bin/test-tool-call
*/ */
@@ -53,6 +67,23 @@ int main() {
"required": ["arg1"] "required": ["arg1"]
} }
} }
},
{
"type": "function",
"function": {
"name": "ipython",
"description": "a python interpreter",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The code."
}
},
"required": ["code"]
}
}
} }
])"); ])");
json request = { json request = {
@@ -83,12 +114,14 @@ int main() {
}} }}
}}); }});
test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools, test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools,
">>>test\n{ } \n ", ">>>special_function\n{\"arg1\": 1}\n ",
"", "",
json {{ json {{
{"function", { {"function", {
{"name", "test"}, {"name", "special_function"},
{"arguments", "{}"} {"arguments", (json {
{"arg1", 1}
}).dump()}
}} }}
}}); }});
@@ -158,5 +191,6 @@ int main() {
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array()); "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array());
std::cout << "[tool-call] All tests passed!" << std::endl;
return 0; return 0;
} }