From d7ec84f78c884a9bd024fab0dbbafb474efdc924 Mon Sep 17 00:00:00 2001 From: ochafik Date: Thu, 26 Sep 2024 06:51:46 +0100 Subject: [PATCH] `tool-call`: allow <|python_tag|> in functionary-medium-3.1 --- common/tool-call.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/common/tool-call.cpp b/common/tool-call.cpp index af2d95cf8..8304069ac 100644 --- a/common/tool-call.cpp +++ b/common/tool-call.cpp @@ -253,7 +253,9 @@ llama_tool_call_handler llama_tool_call_handler_init( }); // handler.parser = parse_functionary_3_2_tool_calls; } else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) { + // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt + // TODO: handle tool {type: code_interpreter} as python handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { std::vector tool_rules; for (size_t i = 0, n = tools.size(); i < n; i++) { @@ -261,8 +263,14 @@ llama_tool_call_handler llama_tool_call_handler_init( const auto & function = tool["function"]; std::string name = function["name"]; auto parameters = function["parameters"]; - auto tool_rule = builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\""); - tool_rules.push_back(tool_rule); + if (name == "python") { + tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); + if (allow_content) { + handler.grammar_trigger_words.push_back("<|python_tag|>"); + } + } else { + tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\"")); + } } auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space"; builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);