diff --git a/common/tool-call.cpp b/common/tool-call.cpp index af2d95cf8..8304069ac 100644 --- a/common/tool-call.cpp +++ b/common/tool-call.cpp @@ -253,7 +253,9 @@ llama_tool_call_handler llama_tool_call_handler_init( }); // handler.parser = parse_functionary_3_2_tool_calls; } else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) { + // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt + // TODO: handle tool {type: code_interpreter} as python handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { std::vector tool_rules; for (size_t i = 0, n = tools.size(); i < n; i++) { @@ -261,8 +263,14 @@ llama_tool_call_handler llama_tool_call_handler_init( const auto & function = tool["function"]; std::string name = function["name"]; auto parameters = function["parameters"]; - auto tool_rule = builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\""); - tool_rules.push_back(tool_rule); + if (name == "python") { + tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); + if (allow_content) { + handler.grammar_trigger_words.push_back("<|python_tag|>"); + } + } else { + tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\"")); + } } auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space"; builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);