From add91241150afc64d2b26ba487bff53e4f25a9b4 Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 27 Jan 2025 20:13:09 +0000 Subject: [PATCH] fix test-chat-handler grammar tests --- common/chat-handler.cpp | 33 ++++++++++--- common/chat-template.hpp | 15 +++--- tests/test-chat-handler.cpp | 93 ++++++++++++++++++------------------- 3 files changed, 82 insertions(+), 59 deletions(-) diff --git a/common/chat-handler.cpp b/common/chat-handler.cpp index feb2245d7..d5f235f21 100644 --- a/common/chat-handler.cpp +++ b/common/chat-handler.cpp @@ -363,6 +363,7 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) { fprintf(stderr, "[%s]\n", __func__); + // TODO: get from request body. auto builtin_tools = json {"wolfram_alpha", "brave_search"}; common_chat_data data; @@ -377,10 +378,16 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c tool_rules.push_back( builder.add_rule( name + "-call", - "\"<|python_tag|>\" \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + + "\"{\" " + // " ( \"\\\"type\\\": \\\"function\\\", \" | space ) " + "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + builder.add_schema(name + "-args", parameters) + " \"}\"")); + if (params.tool_choice != "required") { + data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true}); + } }); + tool_rules.push_back(builder.add_rule("builtin-tool-call", "\"<|python_tag|>\" .*")); if (params.tool_choice != "required") { data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); } @@ -391,11 +398,27 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c {"builtin_tools", builtin_tools}, }); data.parser = std::make_unique([params](const std::string & input) -> common_chat_msg { - static std::regex function_regex("<\\|python_tag\\|>\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": "); + static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": "); static std::regex close_regex("\\}"); + static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\((.*)\)"); + + std::smatch match; + if (std::regex_match(input, match, builtin_call_regex)) { + auto arguments = json::parse("[" + match[2].str() + "]"); + return { + /* .role = */ "assistant", + /* .content = */ match.prefix().str(), + /* .tool_calls = */ { + { + /* .name = */ match[1], + /* .arguments = */ arguments.dump(), + /* .id = */ "", + }, + }, + }; + } return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true); }); - fprintf(stderr, "Grammar: %s\n", data.grammar.c_str()); return data; } @@ -435,7 +458,6 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_ auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true); return res; }); - fprintf(stderr, "Grammar: %s\n", data.grammar.c_str()); return data; } @@ -590,9 +612,8 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons } else if (type != "string") { throw std::runtime_error("Invalid type in python tool: " + type.dump()); } - } else { - tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\" space")); } + tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\" space")); }); if (has_raw_python) { tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); diff --git a/common/chat-template.hpp b/common/chat-template.hpp index 89ce933a0..f528875f7 100644 --- a/common/chat-template.hpp +++ b/common/chat-template.hpp @@ -129,6 +129,7 @@ class chat_template { bool supports_tools() const { return supports_tools_; } bool supports_tool_calls() const { return supports_tool_calls_; } bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; } + bool requires_object_arguments() const { return requires_object_arguments_; } std::string apply( const nlohmann::ordered_json & messages, @@ -201,12 +202,14 @@ class chat_template { for (auto & tool_call : message.at("tool_calls")) { if (tool_call["type"] == "function") { auto & function = tool_call.at("function"); - std::string arguments = function.at("arguments"); - try { - function["arguments"] = json::parse(arguments); - } catch (const std::exception & ecvt) { - fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what()); - function["arguments"] = arguments; + auto & arguments = function.at("arguments"); + if (arguments.is_string()) { + try { + arguments = json::parse(arguments.get()); + } catch (const std::exception & ecvt) { + fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what()); + arguments = arguments; + } } } } diff --git a/tests/test-chat-handler.cpp b/tests/test-chat-handler.cpp index 88366c685..5e9db450e 100644 --- a/tests/test-chat-handler.cpp +++ b/tests/test-chat-handler.cpp @@ -72,32 +72,17 @@ static std::string dump(const json & j) { return minja::Value(j).dump(-1, /* to_json= */ true); } -static void assert_msg_equals(const common_chat_msg & result, const std::string & expected_content, const json & expected_tool_calls) { - assert_equals(expected_content, result.content); - auto tool_calls = json::array(); - for (const auto & tc : result.tool_calls) { - auto arguments = tc.arguments; - try { - arguments = dump(json::parse(arguments)); - } catch (const std::exception & e) { - // ignore - } - auto tool_call = json { - {"type", "function"}, - {"function", { - {"arguments", arguments}, - {"name", tc.name}, - }}, - }; - if (!tc.id.empty()) { - tool_call["id"] = tc.id; - } - tool_calls.push_back(tool_call); +static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) { + assert_equals(expected.role, actual.role); + assert_equals(expected.content, actual.content); + assert_equals(expected.tool_calls.size(), actual.tool_calls.size()); + for (size_t i = 0; i < expected.tool_calls.size(); i++) { + const auto & expected_tool_call = expected.tool_calls[i]; + const auto & actual_tool_call = actual.tool_calls[i]; + assert_equals(expected_tool_call.name, actual_tool_call.name); + assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments))); + assert_equals(expected_tool_call.id, actual_tool_call.id); } - // Reparse / dump w/ non-ordered JSON variant. - auto expected = nlohmann::json::parse(expected_tool_calls.dump()).dump(); - auto actual = nlohmann::json::parse(tool_calls.dump()).dump(); - assert_equals(expected, actual); } const auto special_function_tool = json::parse(R"({ @@ -373,7 +358,19 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c static void test_template(const common_chat_template & tmpl, const std::vector & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) { // auto tool_call_style = common_tool_call_style_detect(tmpl); - auto & tool_calls = tool_calling_message.at("tool_calls"); + common_chat_msg expected_msg { + "assistant", + "", + {}, + }; + for (const auto & tc : tool_calling_message.at("tool_calls")) { + const auto & arguments = tc.at("function").at("arguments"); + expected_msg.tool_calls.push_back({ + tc.at("function").at("name").get(), + arguments.is_string() ? arguments.get() : arguments.dump(), + tc.contains("id") ? tc.at("id").get() : "", + }); + } // Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false, // get the diff and try and parse it w/ the grammar. @@ -398,12 +395,12 @@ static void test_template(const common_chat_template & tmpl, const std::vectorparse_final(full_delta); - assert_msg_equals(msg, "", tool_calls); + assert_msg_equals(expected_msg, msg); auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, { {"role", "assistant"}, {"content", {}}, - {"tool_calls", tool_calls} + {"tool_calls", tool_calling_message.at("tool_calls")} }, tools); if (!match_string(content_less_delta, grammar.get())) { throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar); @@ -433,7 +430,9 @@ static void test_grammars() { {"type", "function"}, {"function", { {"name", "python"}, - {"arguments", "print('hey')"} + {"arguments", { + {"code", "print('hey')"}, + }}, }}, }}} }; @@ -442,12 +441,12 @@ static void test_grammars() { const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "", ""); test_template(tmpl, { "" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true); } - // { - // const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "", ""); - // // assert_equals(tmpl.requires_object_arguments_, true); - // test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools); - // test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools); - // } + { + const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "", ""); + // assert_equals(tmpl.requires_object_arguments_, true); + test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools); + test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools); + } { const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", ""); test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools); @@ -456,22 +455,22 @@ static void test_grammars() { const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "", ""); test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools); } - // { - // const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "", ""); - // test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); - // } - // { - // const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "", ""); - // test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools); - // } + { + const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "", ""); + test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); + } + { + const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "", ""); + test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools); + } { const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "", ""); test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); } - // { - // const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "", ""); - // test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); - // } + { + const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "", ""); + test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); + } { const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "", ""); test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);