From ef9efc9ed3a53aa55f11135e646773d3fa8fef6f Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 28 Jan 2025 01:04:06 +0000 Subject: [PATCH] Fix Llama 3.1 (incl. constrained builtin tools e.g. `<|python_tag|>foo.call(arg=vallue)`) --- common/chat-handler.cpp | 97 +++++++++++++++----- examples/server/tests/unit/test_tool_call.py | 12 +-- tests/test-chat-handler.cpp | 39 +++++++- 3 files changed, 116 insertions(+), 32 deletions(-) diff --git a/common/chat-handler.cpp b/common/chat-handler.cpp index 19b11d689..2348fab55 100644 --- a/common/chat-handler.cpp +++ b/common/chat-handler.cpp @@ -207,7 +207,6 @@ static void foreach_function(const json & tools, const std::function([](const std::string & input) { - return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); - }); + return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); + }); return data; } -static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) { - fprintf(stderr, "[%s]\n", __func__); - // TODO: get from request body. - auto builtin_tools = json {"wolfram_alpha", "brave_search"}; - common_chat_data data; +static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector & expected_properties) { + if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) { + throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties"); + } + const auto & parameters_properties = parameters.at("properties"); + const auto & parameters_required = parameters.at("required"); + for (const auto & prop : expected_properties) { + if (!parameters_properties.contains(prop)) { + throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); + } + if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) { + throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); + } + } + if (parameters_properties.size() != expected_properties.size()) { + throw std::runtime_error("Parameters of tool " + name + " must only have these properties:" + string_join(expected_properties, ", ")); + } +} +static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) { + auto builtin_tools = json::array(); + common_chat_data data; data.grammar_lazy = params.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; + auto handle_builtin_tool = [&](const std::string & name, const json & parameters) { + if (name == "wolfram_alpha") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py + expect_tool_parameters(name, parameters, {"query"}); + } else if (name == "web_search" || name == "brave_search") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py + expect_tool_parameters(name, parameters, {"query"}); + } else if (name == "python" || name == "code_interpreter") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py + expect_tool_parameters(name, parameters, {"code"}); + } else { + return false; + } + + std::vector kvs; + for (const auto & [key, value] : parameters.at("properties").items()) { + kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); + } + + tool_rules.push_back( + builder.add_rule( + name + "-call", + "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\"")); + builtin_tools.push_back(name); + + return true; + }; + foreach_function(params.tools, [&](const json & tool) { const auto & function = tool["function"]; std::string name = function["name"]; auto parameters = function["parameters"]; + + // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime + if (handle_builtin_tool(name, parameters)) { + return; + } builder.resolve_refs(parameters); tool_rules.push_back( builder.add_rule( @@ -388,30 +432,42 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c " \"}\"")); data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true}); }); - tool_rules.push_back(builder.add_rule("builtin-tool-call", "\"<|python_tag|>\" .*")); - data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); + if (!builtin_tools.empty()) { + data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); + } builder.add_rule("root", string_join(tool_rules, " | ")); }, grammar_options); data.additional_stops.push_back("<|eom_id|>"); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, { - {"builtin_tools", builtin_tools}, + {"tools_in_user_message", false}, + {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools}, }); data.format = "llama 3.1 tool calls"; data.parser = std::make_unique([params](const std::string & input) -> common_chat_msg { static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": "); static std::regex close_regex("\\}"); - static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\((.*)\)"); + static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)"); std::smatch match; if (std::regex_match(input, match, builtin_call_regex)) { - auto arguments = json::parse("[" + match[2].str() + "]"); + auto name = match[1].str(); + auto raw_args = match[2].str(); + + // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing. + auto it_eq = raw_args.find('='); + auto arg_name = raw_args.substr(0, it_eq); + auto arg_value_str = raw_args.substr(it_eq + 1); + auto arg_value = json::parse(arg_value_str); + return { /* .role = */ "assistant", /* .content = */ match.prefix().str(), /* .tool_calls = */ { { /* .name = */ match[1], - /* .arguments = */ arguments.dump(), + /* .arguments = */ (json { + {arg_name, arg_value}, + }).dump(), /* .id = */ "", }, }, @@ -423,7 +479,6 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c } static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) { - fprintf(stderr, "[%s]\n", __func__); common_chat_data data; data.grammar_lazy = params.tool_choice != "required"; @@ -462,7 +517,6 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_ } static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { - fprintf(stderr, "[%s]\n", __func__); common_chat_data data; data.grammar_lazy = params.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { @@ -490,7 +544,6 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat } static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { - fprintf(stderr, "[%s]\n", __func__); common_chat_data data; data.grammar_lazy = params.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { @@ -529,7 +582,6 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_ } static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { - fprintf(stderr, "[%s]\n", __func__); // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar common_chat_data data; @@ -574,7 +626,6 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common } static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { - fprintf(stderr, "[%s]\n", __func__); // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt common_chat_data data; @@ -651,7 +702,6 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons } static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { - fprintf(stderr, "[%s]\n", __func__); common_chat_data data; // (content)?({"name": "foo", "arguments": {"a": 1}})* data.grammar_lazy = params.tool_choice != "required"; @@ -705,9 +755,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha if (!parse_json(it, end, call)) { throw std::runtime_error("Failed to parse json tool call"); } + const auto & arguments = call["arguments"]; result.tool_calls.push_back({ call["name"], - call["arguments"].dump(), + arguments.dump(), + // arguments.is_string() ? arguments.get() : arguments.dump(), /* id= */ "", }); rit = {it, end, middle_pattern}; @@ -734,7 +786,6 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha } static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) { - fprintf(stderr, "[%s]\n", __func__); common_chat_data data; data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); data.format = "content-only"; diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py index 0c9dc6bd4..86358d7d1 100644 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -63,6 +63,8 @@ WEATHER_TOOL = { @pytest.mark.parametrize("template_name,tool,argument_key", [ + ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"), ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"), ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"), ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"), @@ -78,8 +80,6 @@ WEATHER_TOOL = { ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"), ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"), # TODO: fix these - # ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"), - # ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"), ]) def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None): n_predict = 512 @@ -118,6 +118,8 @@ def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argu @pytest.mark.slow @pytest.mark.parametrize("tool,argument_key,hf_repo,hf_file,template_override", [ + (TEST_TOOL, "success", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), + (PYTHON_TOOL, "code", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None), (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None), (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None), @@ -139,8 +141,6 @@ def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argu # TODO: fix these # (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None), # (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None), - # (TEST_TOOL, "success", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), - # (PYTHON_TOOL, "code", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), ]) def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None): n_predict = 512 @@ -218,6 +218,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools: @pytest.mark.slow @pytest.mark.parametrize("hf_repo,hf_file,template_override", [ + ("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), ("bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None), ("bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None), ("bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None), @@ -229,7 +230,6 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools: ("bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), # TODO: fix these # ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None), - # ("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), ]) def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None): global server @@ -267,6 +267,7 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[ @pytest.mark.slow @pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [ + ('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), (None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)), (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), @@ -277,7 +278,6 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[ (None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None), # TODO: fix these - # ('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None), ]) def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None): diff --git a/tests/test-chat-handler.cpp b/tests/test-chat-handler.cpp index cccc98db8..a5c28e958 100644 --- a/tests/test-chat-handler.cpp +++ b/tests/test-chat-handler.cpp @@ -119,7 +119,25 @@ const auto python_tool = json::parse(R"({ } } })"); +const auto code_interpreter_tool = json::parse(R"({ + "type": "function", + "function": { + "name": "code_interpreter", + "description": "an ipython interpreter", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute." + } + }, + "required": ["code"] + } + } +})"); const json tools = {special_function_tool, python_tool}; +const json llama_3_1_tools = {special_function_tool, code_interpreter_tool}; // static void test_parsing() { // json request = { @@ -427,6 +445,19 @@ static void test_grammars() { }}, }}} }; + auto code_interpreter_tool_call_message = json { + {"role", "assistant"}, + {"content", {}}, + {"tool_calls", json {{ + {"type", "function"}, + {"function", { + {"name", "code_interpreter"}, + {"arguments", { + {"code", "print('hey')"}, + }}, + }}, + }}} + }; common_chat_params no_tools_params; @@ -494,10 +525,12 @@ static void test_grammars() { const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "", ""); std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; - assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params)); - test_template(tmpl, end_tokens, text_message, tools); + // assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params)); + // test_template(tmpl, end_tokens, text_message, tools); + test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools); + test_template(tmpl, end_tokens, python_tool_call_message, tools); test_template(tmpl, end_tokens, tool_call_message, tools); - test_template(tmpl, end_tokens, python_tool_call_message, tools); + test_template(tmpl, end_tokens, tool_call_message, llama_3_1_tools); } { const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "", "");