From ca0c837b6a7b9883204b9c4baba7598f9ef45d88 Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 27 Jan 2025 01:08:29 +0000 Subject: [PATCH] nits --- Makefile | 6 +-- common/chat-handler.cpp | 78 +++++++++++++++++++++------------------ common/chat-template.hpp | 7 +--- examples/server/README.md | 4 +- src/llama-grammar.cpp | 2 +- 5 files changed, 50 insertions(+), 47 deletions(-) diff --git a/Makefile b/Makefile index ed04dc176..529fc6313 100644 --- a/Makefile +++ b/Makefile @@ -52,6 +52,7 @@ TEST_TARGETS = \ tests/test-arg-parser \ tests/test-autorelease \ tests/test-backend-ops \ + tests/test-chat-handler \ tests/test-chat-template \ tests/test-double-float \ tests/test-grammar-integration \ @@ -64,7 +65,6 @@ TEST_TARGETS = \ tests/test-quantize-perf \ tests/test-rope \ tests/test-sampling \ - tests/test-tool-call \ tests/test-tokenizer-0 \ tests/test-tokenizer-1-bpe \ tests/test-tokenizer-1-spm @@ -984,8 +984,8 @@ OBJ_COMMON = \ $(DIR_COMMON)/ngram-cache.o \ $(DIR_COMMON)/sampling.o \ $(DIR_COMMON)/speculative.o \ + $(DIR_COMMON)/chat-handler.o \ $(DIR_COMMON)/build-info.o \ - $(DIR_COMMON)/tool-call.o \ $(DIR_COMMON)/json-schema-to-grammar.o OBJ_ALL = $(OBJ_GGML) $(OBJ_LLAMA) $(OBJ_COMMON) @@ -1475,7 +1475,7 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \ $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -tests/test-tool-call: tests/test-tool-call.cpp \ +tests/test-chat-handler: tests/test-chat-handler.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) diff --git a/common/chat-handler.cpp b/common/chat-handler.cpp index abbabe069..511fa1aef 100644 --- a/common/chat-handler.cpp +++ b/common/chat-handler.cpp @@ -58,7 +58,7 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. * Aggregates the prefix, suffix and in-between text into the content. */ -static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) { +static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool has_python) { std::smatch match; common_chat_msg result; @@ -102,7 +102,7 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri json arguments; if (!parse_json(it, end, arguments)) { - if (name == "python" && std::regex_match("", close_regex)) { + if (has_python && name == "python" && std::regex_match("", close_regex)) { std::string src(it, end); result.tool_calls.push_back({name, src, /* id= */ ""}); break; @@ -232,7 +232,7 @@ static void foreach_normalized_tool(const json & tools, const std::function tool_rules; - + auto add_tool = [&](const json & tool) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + builder.resolve_refs(parameters); + if (uses_python_tag && (name == "python" || name == "ipython" || builtin_tools.contains(name))) { + has_python = true; + } else { + //"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " + + tool_rules.push_back( + builder.add_rule( + name + "-call", + "\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + + builder.add_schema(name + "-args", parameters) + + " \"}\"")); + if (params.tool_choice != "required" && !eagerly_match_any_json) { + data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ false}); + // Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B. + // Note that c++11's regex doesn't support partial matches, otherwise it would make + // sense to add support for trigger regexes to the antiprompt mechanism. + data.grammar_triggers.push_back({"{\n\t\"name\": \"" + name + "\"", /* .at_start = */ false}); + data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false}); + data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false}); + data.grammar_triggers.push_back({"{\"type\": \"function\", \"name\": \"" + name + "\"", /* .at_start = */ false}); + } + } + }; for (const auto & tool : params.tools) { if (!tool.contains("type")) { continue; @@ -410,38 +436,18 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te builtin_tools.push_back("code_interpreter"); has_python = true; } else if (tool["type"] == "function" && tool.contains("function")) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; - builder.resolve_refs(parameters); - if (uses_python_tag && (name == "python" || name == "ipython" || builtin_tools.contains(name))) { - has_python = true; - } else { - //"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " + - tool_rules.push_back( - builder.add_rule( - name + "-call", - "\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + - builder.add_schema(name + "-args", parameters) + - " \"}\"")); - if (params.tool_choice != "required" && !eagerly_match_any_json) { - data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ false}); - // Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B. - // Note that c++11's regex doesn't support partial matches, otherwise it would make - // sense to add support for trigger regexes to the antiprompt mechanism. - data.grammar_triggers.push_back({"{\n\t\"name\": \"" + name + "\"", /* .at_start = */ false}); - data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false}); - data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false}); - data.grammar_triggers.push_back({"{\"type\": \"function\", \"name\": \"" + name + "\"", /* .at_start = */ false}); - } - } + add_tool(tool); } } - if (has_python && uses_python_tag) { - tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*")); - if (params.tool_choice != "required") { - data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); + if (has_python) { + if (uses_python_tag) { + tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*")); + if (params.tool_choice != "required") { + data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); + } + } else { + add_tool(python_tool); } } @@ -478,7 +484,7 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te } static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": "); static std::regex close_regex("\\}"); - return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true); + return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, has_python && uses_python_tag); }); return data; } @@ -568,7 +574,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const data.parser = std::make_unique([params, has_python](const std::string & input) -> common_chat_msg { static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); static std::regex close_regex(R"($|(?=>>>))"); - return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true); + return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, has_python); }); return data; } @@ -633,7 +639,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons } static std::regex function_regex(R"()"); static std::regex close_regex(R"()"); - return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false); + return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false, /* has_python= */ false); }); return data; } diff --git a/common/chat-template.hpp b/common/chat-template.hpp index e0a9a1c56..a56cf4d2a 100644 --- a/common/chat-template.hpp +++ b/common/chat-template.hpp @@ -20,7 +20,7 @@ namespace minja { class chat_template { public: -// private: + private: bool supports_tools_ = true; // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object. // Most other templates (and OpenAI's API) expect the arguments object to be stringified. @@ -147,7 +147,7 @@ class chat_template { static const auto python_tool = json::parse(R"({ "type": "function", "function": { - "name": "ipython", + "name": "python", "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", "parameters": { "type": "object", @@ -284,9 +284,6 @@ class chat_template { } else { actual_messages = messages; } - // if (adjust_inputs) { - // fprintf(stderr, "Messages: %s\n", actual_messages.dump(2).c_str()); - // } auto context = minja::Context::make(json({ {"messages", actual_messages}, diff --git a/examples/server/README.md b/examples/server/README.md index 89020bccb..7272204cd 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -1118,7 +1118,7 @@ curl http://localhost:8080/v1/chat/completions \ { "type": "function", "function": { - "name": "ipython", + "name": "python", "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", "parameters": { "type": "object", @@ -1155,7 +1155,7 @@ curl http://localhost:8080/v1/chat/completions \ "content": null, "tool_calls": [ { - "name": "ipython", + "name": "python", "arguments": "{\"code\":\" \\nprint(\\\"Hello, World!\\\")\"}" } ], diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 2eae29bb9..bb2d3f3c4 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -560,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) { } } } catch (const std::exception & err) { - fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); + fprintf(stderr, "\n%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src); rules.clear(); return false; }