From da606d8d41cb700de11160594e93a235f6869798 Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 27 Jan 2025 14:19:20 +0000 Subject: [PATCH] tool-call: remove nonsensical code_interpreter code --- common/chat-handler.cpp | 308 ++++++++---------- common/chat-template.hpp | 28 +- examples/server/server.cpp | 7 +- .../server/tests/unit/test_chat_completion.py | 62 ++-- 4 files changed, 165 insertions(+), 240 deletions(-) diff --git a/common/chat-handler.cpp b/common/chat-handler.cpp index 511fa1aef..74805f222 100644 --- a/common/chat-handler.cpp +++ b/common/chat-handler.cpp @@ -1,6 +1,7 @@ #include "chat-handler.hpp" #include "chat-template.hpp" #include "json-schema-to-grammar.h" +#include "log.h" #include "minja.hpp" const common_grammar_options grammar_options { @@ -58,7 +59,7 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. * Aggregates the prefix, suffix and in-between text into the content. */ -static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool has_python) { +static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool allow_raw_python = false) { std::smatch match; common_chat_msg result; @@ -69,15 +70,10 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri std::vector tool_names; if (check_names) { for (const auto & tool : tools) { - if (!tool.contains("type")) { + if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) { continue; } - std::string type = tool.at("type"); - if (type == "function") { - tool_names.push_back(tool["function"]["name"]); - } else if (type == "code_interpreter") { - tool_names.push_back("python"); - } + tool_names.push_back(tool["function"]["name"]); } } @@ -102,7 +98,7 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri json arguments; if (!parse_json(it, end, arguments)) { - if (has_python && name == "python" && std::regex_match("", close_regex)) { + if (allow_raw_python && name == "python" && std::regex_match("", close_regex)) { std::string src(it, end); result.tool_calls.push_back({name, src, /* id= */ ""}); break; @@ -207,42 +203,22 @@ public: } }; -const auto python_tool = json::parse(R"({ - "type": "function", - "function": { - "name": "python", - "description": "an ipython interpreter", - "parameters": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "Python code to execute." - } - }, - "required": ["code"] - } - } -})"); - -static void foreach_normalized_tool(const json & tools, const std::function & fn) { +static void foreach_function(const json & tools, const std::function & fn) { for (const auto & tool : tools) { - if (!tool.contains("type")) { + if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) { + LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str()); continue; } - if (tool["type"] == "code_interpreter") { - fn(python_tool); - } else if (tool["type"] == "function" && tool.contains("function")) { - fn(tool); - } + fn(tool); } } static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { + fprintf(stderr, "[%s]\n", __func__); common_chat_data data; auto tool_call_schemas = json::array(); - foreach_normalized_tool(params.tools, [&](const json & tool) { + foreach_function(params.tools, [&](const json & tool) { const auto & function = tool["function"]; auto tool_schema = json { {"type", "object"}, @@ -348,10 +324,11 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem } static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { + fprintf(stderr, "[%s]\n", __func__); common_chat_data data; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); - foreach_normalized_tool(params.tools, [&](const json & tool) { + foreach_function(params.tools, [&](const json & tool) { const auto & function = tool["function"]; schemas.push_back({ {"type", "object"}, @@ -392,108 +369,91 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha return data; } -static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) { +static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) { + fprintf(stderr, "[%s]\n", __func__); auto builtin_tools = json {"wolfram_alpha", "brave_search"}; common_chat_data data; - auto has_python = false; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; - auto add_tool = [&](const json & tool) { + foreach_function(params.tools, [&](const json & tool) { const auto & function = tool["function"]; std::string name = function["name"]; auto parameters = function["parameters"]; builder.resolve_refs(parameters); - if (uses_python_tag && (name == "python" || name == "ipython" || builtin_tools.contains(name))) { - has_python = true; - } else { - //"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " + - tool_rules.push_back( - builder.add_rule( - name + "-call", - "\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + - builder.add_schema(name + "-args", parameters) + - " \"}\"")); - if (params.tool_choice != "required" && !eagerly_match_any_json) { - data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ false}); - // Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B. - // Note that c++11's regex doesn't support partial matches, otherwise it would make - // sense to add support for trigger regexes to the antiprompt mechanism. - data.grammar_triggers.push_back({"{\n\t\"name\": \"" + name + "\"", /* .at_start = */ false}); - data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false}); - data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false}); - data.grammar_triggers.push_back({"{\"type\": \"function\", \"name\": \"" + name + "\"", /* .at_start = */ false}); - } - } - }; - for (const auto & tool : params.tools) { - if (!tool.contains("type")) { - continue; - } - - if (tool["type"] == "code_interpreter") { - builtin_tools.push_back("code_interpreter"); - has_python = true; - } else if (tool["type"] == "function" && tool.contains("function")) { - add_tool(tool); - } + tool_rules.push_back( + builder.add_rule( + name + "-call", + "\"<|python_tag|>\" \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + + builder.add_schema(name + "-args", parameters) + + " \"}\"")); + }); + if (params.tool_choice != "required") { + data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); } - - if (has_python) { - if (uses_python_tag) { - tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*")); - if (params.tool_choice != "required") { - data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); - } - } else { - add_tool(python_tool); - } - } - - if (params.tool_choice != "required" && eagerly_match_any_json) { - data.grammar_triggers.push_back({"{\"", /* .at_start = */ true}); - data.grammar_triggers.push_back({"{\n\t\"", /* .at_start = */ true}); - data.grammar_triggers.push_back({"{\n \"", /* .at_start = */ true}); - data.grammar_triggers.push_back({"{\n \"", /* .at_start = */ true}); - } - builder.add_rule("root", string_join(tool_rules, " | ")); }, grammar_options); data.additional_stops.push_back("<|eom_id|>"); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, { {"builtin_tools", builtin_tools}, }); - data.parser = std::make_unique([params, has_python, uses_python_tag](const std::string & input) -> common_chat_msg { - if (has_python && uses_python_tag) { - static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); - std::smatch match; - if (std::regex_search(input, match, python_tag_regex)) { - return { - /* .role = */ "assistant", - /* .content = */ match.prefix().str(), - /* .tool_calls = */ { - { - /* .name = */ "python", - /* .arguments = */ match[1].str(), - /* .id = */ "", - }, - } - }; - } - } - static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": "); + data.parser = std::make_unique([params](const std::string & input) -> common_chat_msg { + static std::regex function_regex("<\\|python_tag\\|>\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": "); static std::regex close_regex("\\}"); - return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, has_python && uses_python_tag); + auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true); + return res; }); + fprintf(stderr, "Grammar: %s\n", data.grammar.c_str()); + return data; +} + +static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) { + fprintf(stderr, "[%s]\n", __func__); + common_chat_data data; + + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + + // auto add_tool = [&](const json & tool) { + foreach_function(params.tools, [&](const json & tool) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + builder.resolve_refs(parameters); + tool_rules.push_back( + builder.add_rule( + name + "-call", + "\"{\" " + // " ( \"\\\"type\\\": \\\"function\\\", \" | space ) " + "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + + builder.add_schema(name + "-args", parameters) + + " \"}\"")); + if (params.tool_choice != "required") { + data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true}); + } + }); + + builder.add_rule("root", string_join(tool_rules, " | ")); + }, grammar_options); + data.additional_stops.push_back("<|eom_id|>"); + data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {}); + data.parser = std::make_unique([params](const std::string & input) -> common_chat_msg { + static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": "); + static std::regex close_regex("\\}"); + auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true); + return res; + }); + fprintf(stderr, "Grammar: %s\n", data.grammar.c_str()); return data; } static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { + fprintf(stderr, "[%s]\n", __func__); common_chat_data data; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); - foreach_normalized_tool(params.tools, [&](const json & tool) { + foreach_function(params.tools, [&](const json & tool) { const auto & function = tool["function"]; schemas.push_back({ {"type", "object"}, @@ -528,85 +488,92 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_ } static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { + fprintf(stderr, "[%s]\n", __func__); // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar common_chat_data data; - auto has_python = false; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector first_tool_rules; std::vector subsequent_tool_rules; - for (const auto & tool : params.tools) { - if (!tool.contains("type")) { - continue; + foreach_function(params.tools, [&](const json & tool) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + auto args_rule = builder.add_schema(name + "-args", parameters); + first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); + subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); + if (params.tool_choice != "required") { + data.grammar_triggers.push_back({name, /* .at_start = */ true}); + data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false}); } - if (tool["type"] == "code_interpreter") { - has_python = true; - } else if (tool["type"] == "function" && tool.contains("function")) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; - auto args_rule = builder.add_schema(name + "-args", parameters); - first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); - subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\"\\n>>>" + name + "\\n\" " + args_rule)); - if (params.tool_choice != "required") { - data.grammar_triggers.push_back({name + "\n", /* .at_start = */ true}); - data.grammar_triggers.push_back({"\n>>>" + name + "\n", /* .at_start = */ false}); - } - } - } + }); auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; - // Note: if there's a python rule, it needs to come last. - auto python_rule = builder.add_rule("python-call", "\"python\\n\" .*"); - if (has_python && params.tool_choice != "required") { - data.grammar_triggers.push_back({"python\n", /* .at_start = */ true}); - data.grammar_triggers.push_back({"\n>>>python\n", /* .at_start = */ false}); - } if (params.parallel_tool_calls) { auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space"; - builder.add_rule("root", first_rule.empty() ? python_rule : python_rule + " | " + first_rule + " (" + subsequent_rule + ")*" + (has_python ? " ( \">>>\\n\" " + python_rule + " )?" : "")); + builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*"); } else { - builder.add_rule("root", first_rule.empty() ? python_rule : first_rule + (has_python ? " | " + python_rule : "")); + builder.add_rule("root", first_rule); } + }, grammar_options); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); - data.parser = std::make_unique([params, has_python](const std::string & input) -> common_chat_msg { + data.parser = std::make_unique([params](const std::string & input) -> common_chat_msg { static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); static std::regex close_regex(R"($|(?=>>>))"); - return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, has_python); + + auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true); + if (res.content.find("all\n") == 0) { + res.content = res.content.substr(4); + } + return res; }); return data; } static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { + fprintf(stderr, "[%s]\n", __func__); // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt - // TODO: handle tool {type: code_interpreter} as python common_chat_data data; json tools = params.tools.is_null() ? params.tools : json::array(); + std::string python_code_argument_name; + auto has_raw_python = false; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; - auto has_python = false; - for (const auto & tool : params.tools) { - if (!tool.contains("type")) { - continue; - } - if (tool["type"] == "code_interpreter") { - has_python = true; - } else if (tool["type"] == "function" && tool.contains("function")) { - const auto & function = tool["function"]; - std::string name = function["name"]; - if (name == "python" || name == "ipython") { - has_python = true; - } else { - auto parameters = function["parameters"]; - tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\" space")); + foreach_function(params.tools, [&](const json & tool) { + const auto & function = tool["function"]; + const auto & parameters = function["parameters"]; + std::string name = function["name"]; + if (name == "python" || name == "ipython") { + if (!parameters.contains("type")) { + throw std::runtime_error("Missing type in python tool"); } + has_raw_python = true; + auto type = parameters.at("type"); + if (type == "object") { + auto properties = parameters.at("properties"); + for (auto it = properties.begin(); it != properties.end(); ++it) { + if (it.value().at("type") == "string") { + if (!python_code_argument_name.empty()) { + throw std::runtime_error("Multiple string arguments found in python tool"); + } + python_code_argument_name = it.key(); + } + } + if (python_code_argument_name.empty()) { + throw std::runtime_error("No string argument found in python tool"); + } + } else if (type != "string") { + throw std::runtime_error("Invalid type in python tool: " + type.dump()); + } + } else { + tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\" space")); } - } - if (has_python) { + }); + if (has_raw_python) { tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); if (params.tool_choice != "required") { data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); @@ -620,18 +587,19 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons }, grammar_options); data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); - data.parser = std::make_unique([params](const std::string & input) -> common_chat_msg { + data.parser = std::make_unique([params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg { // This version of Functionary still supports the llama 3.1 tool call format for the python tool. static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); std::smatch match; if (std::regex_search(input, match, python_tag_regex)) { + auto code = match[1].str(); return { /* .role = */ "assistant", /* .content = */ match.prefix().str(), /* .tool_calls = */ { { /* .name = */ "python", - /* .arguments = */ match[1].str(), + /* .arguments = */ python_code_argument_name.empty() ? code : (json {{python_code_argument_name, code}}).dump(), /* .id = */ "", }, } @@ -639,17 +607,18 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons } static std::regex function_regex(R"()"); static std::regex close_regex(R"()"); - return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false, /* has_python= */ false); + return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false, has_raw_python); }); return data; } static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { + fprintf(stderr, "[%s]\n", __func__); common_chat_data data; // (content)?({"name": "foo", "arguments": {"a": 1}})* data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; - foreach_normalized_tool(params.tools, [&](const json & tool) { + foreach_function(params.tools, [&](const json & tool) { const auto & function = tool["function"]; std::string name = function["name"]; auto parameters = function["parameters"]; @@ -719,6 +688,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha } static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) { + fprintf(stderr, "[%s]\n", __func__); common_chat_data data; data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); data.parser = std::make_unique(); @@ -756,13 +726,11 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) { auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos; - // Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name, - // but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon - // as it seems to be outputting some JSON. - // TODO: make this conditional on a very small model (e.g. 1B / 3B). - auto eagerly_match_any_json = false; // style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2; - - return common_chat_init_llama_3_tool_calls(tmpl, params, uses_python_tag, eagerly_match_any_json); + if (uses_python_tag) { + return common_chat_init_llama_3_1_python_tag_tool_calls(tmpl, params); + } else { + return common_chat_init_llama_3_2_tool_calls(tmpl, params); + } } // if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) { // TODO: Command-R-Plus diff --git a/common/chat-template.hpp b/common/chat-template.hpp index a56cf4d2a..62bd535e0 100644 --- a/common/chat-template.hpp +++ b/common/chat-template.hpp @@ -143,28 +143,10 @@ class chat_template { if (adjust_inputs && !tools.is_null() && !supports_code_interpreter_ && has_code_interpreter) { actual_tools = json::array(); for (const auto & tool : tools) { - if (tool.contains("type") && tool.at("type") == "code_interpreter") { - static const auto python_tool = json::parse(R"({ - "type": "function", - "function": { - "name": "python", - "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", - "parameters": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "The code to run in the ipython interpreter." - } - }, - "required": ["code"] - } - } - })"); - actual_tools.push_back(python_tool); - } else { - actual_tools.push_back(tool); + if (tool.contains("type") && tool.at("type") == "code_interpreter" && !supports_code_interpreter_) { + continue; } + actual_tools.push_back(tool); } } else if (!tools.is_null()) { actual_tools = tools; @@ -295,10 +277,6 @@ class chat_template { if (!tools.is_null()) { auto tools_val = minja::Value(actual_tools); context->set("tools", tools_val); - if (has_code_interpreter && !extra_context.contains("builtin_tools")) { - auto builtin_tools_val = minja::Value(json {"code_interpreter"}); - context->set("builtin_tools", builtin_tools_val); - } } if (!extra_context.is_null()) { for (auto & kv : extra_context.items()) { diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e359a3323..a75a7b01f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -333,7 +333,7 @@ struct server_task { if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); } - if (data.contains("json_schema") && !data.contains("grammar")) { + if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { try { auto schema = json_value(data, "json_schema", json::object()); params.sampling.grammar = json_schema_to_grammar(schema); @@ -345,11 +345,6 @@ struct server_task { } } - // process "json_schema" and "grammar" - if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { - throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); - } - if (data.contains("json_schema") && !data.contains("grammar")) { { params.sampling.logit_bias.clear(); params.ignore_eos = json_value(data, "ignore_eos", false); diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index 5dde87e47..16ec6fc93 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -221,34 +221,22 @@ PYTHON_TOOL = { } } -CODE_INTEPRETER_TOOL = { - "type": "code_interpreter", -} - @pytest.mark.parametrize("template_name,tool,argument_key", [ - # TODO: fix special handling of python tool for these templates: ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"), - ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, None), # "code"), # TODO: fix - ("meetkai-functionary-medium-v3.1", CODE_INTEPRETER_TOOL, None), + ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"), ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"), ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"), - ("meetkai-functionary-medium-v3.2", CODE_INTEPRETER_TOOL, None), ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"), ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"), - ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", CODE_INTEPRETER_TOOL, None), ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"), ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"), - ("meta-llama-Llama-3.2-3B-Instruct", CODE_INTEPRETER_TOOL, None), ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"), ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"), - ("mistralai-Mistral-Nemo-Instruct-2407", CODE_INTEPRETER_TOOL, None), ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"), - ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, None), # "code"), # TODO: fix - ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", CODE_INTEPRETER_TOOL, None), + ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"), ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"), - ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, None), # "code"), # TODO: fix - ("meta-llama-Meta-Llama-3.1-8B-Instruct", CODE_INTEPRETER_TOOL, None), + ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"), ]) def test_completion_with_required_tool(template_name: str, tool: dict, argument_key: str | None): n_predict = 512 @@ -321,29 +309,19 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools: @pytest.mark.slow @pytest.mark.parametrize("tool,expected_arguments,hf_repo,hf_file,template_override", [ + (PYTHON_TOOL, None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)), + (PYTHON_TOOL, None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None), + (PYTHON_TOOL, None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None), + (PYTHON_TOOL, None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None), + (PYTHON_TOOL, None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (PYTHON_TOOL, None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), + (PYTHON_TOOL, None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None), # TODO: fix these models - # (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)), - # (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)), - # # (PYTHON_TOOL, {"code": "print(\"Hello, World!\")"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), - # # (CODE_INTEPRETER_TOOL, {"code": "print(\"Hello, World!\")"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), - # (PYTHON_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), - # (CODE_INTEPRETER_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), - # (PYTHON_TOOL, {"code": "print(\"hello world\")"}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), - # (CODE_INTEPRETER_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), - (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None), - (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None), - (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None), - (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None), - (PYTHON_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None), - (CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None), - (PYTHON_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - (CODE_INTEPRETER_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch-Hermes-2-Pro-Llama-3-8B", "tool_use")), - (PYTHON_TOOL, {"code": "print('Hello World!')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), - (CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), - (PYTHON_TOOL, {"code": "print('Hello, World!')\n"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None), - (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')\n"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", ("mistralai-Mistral-Nemo-Instruct-2407", None)), + (PYTHON_TOOL, '{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), + # (PYTHON_TOOL, None, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), + # (PYTHON_TOOL, None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), ]) -def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None): +def test_hello_world_tool_call(tool: dict, expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None): global server server.n_slots = 2 server.jinja = True @@ -377,9 +355,15 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st assert tool["function"]["name"] == tool_call["function"]["name"] elif tool["type"] == "code_interpreter": assert re.match('i?python', tool_call["function"]["name"]) - actual_arguments = json.loads(tool_call["function"]["arguments"]) - code = actual_arguments["code"] - assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}' + actual_arguments = tool_call["function"]["arguments"] + if expected_arguments is not None: + assert actual_arguments == expected_arguments + else: + actual_arguments = json.loads(actual_arguments) + assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}" + code = actual_arguments["code"] + assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}" + assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}' def test_logprobs():