This commit is contained in:
ochafik 2025-01-27 01:08:29 +00:00
parent f7078cab36
commit ca0c837b6a
5 changed files with 50 additions and 47 deletions

View file

@ -52,6 +52,7 @@ TEST_TARGETS = \
tests/test-arg-parser \ tests/test-arg-parser \
tests/test-autorelease \ tests/test-autorelease \
tests/test-backend-ops \ tests/test-backend-ops \
tests/test-chat-handler \
tests/test-chat-template \ tests/test-chat-template \
tests/test-double-float \ tests/test-double-float \
tests/test-grammar-integration \ tests/test-grammar-integration \
@ -64,7 +65,6 @@ TEST_TARGETS = \
tests/test-quantize-perf \ tests/test-quantize-perf \
tests/test-rope \ tests/test-rope \
tests/test-sampling \ tests/test-sampling \
tests/test-tool-call \
tests/test-tokenizer-0 \ tests/test-tokenizer-0 \
tests/test-tokenizer-1-bpe \ tests/test-tokenizer-1-bpe \
tests/test-tokenizer-1-spm tests/test-tokenizer-1-spm
@ -984,8 +984,8 @@ OBJ_COMMON = \
$(DIR_COMMON)/ngram-cache.o \ $(DIR_COMMON)/ngram-cache.o \
$(DIR_COMMON)/sampling.o \ $(DIR_COMMON)/sampling.o \
$(DIR_COMMON)/speculative.o \ $(DIR_COMMON)/speculative.o \
$(DIR_COMMON)/chat-handler.o \
$(DIR_COMMON)/build-info.o \ $(DIR_COMMON)/build-info.o \
$(DIR_COMMON)/tool-call.o \
$(DIR_COMMON)/json-schema-to-grammar.o $(DIR_COMMON)/json-schema-to-grammar.o
OBJ_ALL = $(OBJ_GGML) $(OBJ_LLAMA) $(OBJ_COMMON) OBJ_ALL = $(OBJ_GGML) $(OBJ_LLAMA) $(OBJ_COMMON)
@ -1475,7 +1475,7 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
tests/test-tool-call: tests/test-tool-call.cpp \ tests/test-chat-handler: tests/test-chat-handler.cpp \
$(OBJ_ALL) $(OBJ_ALL)
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

View file

@ -58,7 +58,7 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
* Aggregates the prefix, suffix and in-between text into the content. * Aggregates the prefix, suffix and in-between text into the content.
*/ */
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) { static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool has_python) {
std::smatch match; std::smatch match;
common_chat_msg result; common_chat_msg result;
@ -102,7 +102,7 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
json arguments; json arguments;
if (!parse_json(it, end, arguments)) { if (!parse_json(it, end, arguments)) {
if (name == "python" && std::regex_match("", close_regex)) { if (has_python && name == "python" && std::regex_match("", close_regex)) {
std::string src(it, end); std::string src(it, end);
result.tool_calls.push_back({name, src, /* id= */ ""}); result.tool_calls.push_back({name, src, /* id= */ ""});
break; break;
@ -232,7 +232,7 @@ static void foreach_normalized_tool(const json & tools, const std::function<void
} }
if (tool["type"] == "code_interpreter") { if (tool["type"] == "code_interpreter") {
fn(python_tool); fn(python_tool);
} else { } else if (tool["type"] == "function" && tool.contains("function")) {
fn(tool); fn(tool);
} }
} }
@ -400,7 +400,33 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
// Registers one function-style tool with the grammar being built.
//
// `tool` is read as {"function": {"name": ..., "parameters": ...}}; both keys are
// accessed unconditionally, so callers are expected to pass a tool that has them.
//
// Effects (all through by-reference captures):
//   - python-like / builtin tools, when `uses_python_tag` is set: only flips
//     `has_python`; no per-tool rule is emitted here.
//   - any other tool: appends a "<name>-call" rule to `tool_rules` and, unless
//     tool_choice is "required" or `eagerly_match_any_json` is set, appends the
//     literal prefixes below to `data.grammar_triggers`.
auto add_tool = [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
// Work on a copy of the schema: resolve_refs receives it as a non-const lvalue.
auto parameters = function["parameters"];
builder.resolve_refs(parameters);
if (uses_python_tag && (name == "python" || name == "ipython" || builtin_tools.contains(name))) {
has_python = true;
} else {
//"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
// Rule body, unescaped: "\n"? "{" ( "\"type\": \"function\", " | space ) "\"name\": \"<name>\", \"parameters\": " <name>-args "}"
// i.e. an optional leading newline, then a JSON object carrying the tool name and
// arguments constrained by the schema rule returned from add_schema.
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) +
" \"}\""));
// Triggers are skipped when a tool call is mandatory or when any JSON is matched eagerly.
if (params.tool_choice != "required" && !eagerly_match_any_json) {
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ false});
// Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B.
// Note that C++11's regex doesn't support partial matches, otherwise it would make
// sense to add support for trigger regexes to the antiprompt mechanism.
data.grammar_triggers.push_back({"{\n\t\"name\": \"" + name + "\"", /* .at_start = */ false});
// NOTE(review): the next two triggers read identically here; presumably they are meant to
// differ in indentation width (e.g. 2 vs 4 spaces) — confirm the literals are distinct.
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\"type\": \"function\", \"name\": \"" + name + "\"", /* .at_start = */ false});
}
}
};
for (const auto & tool : params.tools) { for (const auto & tool : params.tools) {
if (!tool.contains("type")) { if (!tool.contains("type")) {
continue; continue;
@ -410,38 +436,18 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te
builtin_tools.push_back("code_interpreter"); builtin_tools.push_back("code_interpreter");
has_python = true; has_python = true;
} else if (tool["type"] == "function" && tool.contains("function")) { } else if (tool["type"] == "function" && tool.contains("function")) {
const auto & function = tool["function"]; add_tool(tool);
std::string name = function["name"];
auto parameters = function["parameters"];
builder.resolve_refs(parameters);
if (uses_python_tag && (name == "python" || name == "ipython" || builtin_tools.contains(name))) {
has_python = true;
} else {
//"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) +
" \"}\""));
if (params.tool_choice != "required" && !eagerly_match_any_json) {
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ false});
// Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B.
// Note that c++11's regex doesn't support partial matches, otherwise it would make
// sense to add support for trigger regexes to the antiprompt mechanism.
data.grammar_triggers.push_back({"{\n\t\"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\"type\": \"function\", \"name\": \"" + name + "\"", /* .at_start = */ false});
}
}
} }
} }
if (has_python && uses_python_tag) { if (has_python) {
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*")); if (uses_python_tag) {
if (params.tool_choice != "required") { tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
}
} else {
add_tool(python_tool);
} }
} }
@ -478,7 +484,7 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te
} }
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": "); static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
static std::regex close_regex("\\}"); static std::regex close_regex("\\}");
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true); return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, has_python && uses_python_tag);
}); });
return data; return data;
} }
@ -568,7 +574,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const
data.parser = std::make_unique<monolithic_chat_parser>([params, has_python](const std::string & input) -> common_chat_msg { data.parser = std::make_unique<monolithic_chat_parser>([params, has_python](const std::string & input) -> common_chat_msg {
static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
static std::regex close_regex(R"($|(?=>>>))"); static std::regex close_regex(R"($|(?=>>>))");
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true); return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, has_python);
}); });
return data; return data;
} }
@ -633,7 +639,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
} }
static std::regex function_regex(R"(<function=(\w+)>)"); static std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)"); static std::regex close_regex(R"(</function>)");
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false); return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false, /* has_python= */ false);
}); });
return data; return data;
} }

View file

@ -20,7 +20,7 @@ namespace minja {
class chat_template { class chat_template {
public: public:
// private: private:
bool supports_tools_ = true; bool supports_tools_ = true;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object. // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified. // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
@ -147,7 +147,7 @@ class chat_template {
static const auto python_tool = json::parse(R"({ static const auto python_tool = json::parse(R"({
"type": "function", "type": "function",
"function": { "function": {
"name": "ipython", "name": "python",
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"parameters": { "parameters": {
"type": "object", "type": "object",
@ -284,9 +284,6 @@ class chat_template {
} else { } else {
actual_messages = messages; actual_messages = messages;
} }
// if (adjust_inputs) {
// fprintf(stderr, "Messages: %s\n", actual_messages.dump(2).c_str());
// }
auto context = minja::Context::make(json({ auto context = minja::Context::make(json({
{"messages", actual_messages}, {"messages", actual_messages},

View file

@ -1118,7 +1118,7 @@ curl http://localhost:8080/v1/chat/completions \
{ {
"type": "function", "type": "function",
"function": { "function": {
"name": "ipython", "name": "python",
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"parameters": { "parameters": {
"type": "object", "type": "object",
@ -1155,7 +1155,7 @@ curl http://localhost:8080/v1/chat/completions \
"content": null, "content": null,
"tool_calls": [ "tool_calls": [
{ {
"name": "ipython", "name": "python",
"arguments": "{\"code\":\" \\nprint(\\\"Hello, World!\\\")\"}" "arguments": "{\"code\":\" \\nprint(\\\"Hello, World!\\\")\"}"
} }
], ],

View file

@ -560,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) {
} }
} }
} catch (const std::exception & err) { } catch (const std::exception & err) {
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); fprintf(stderr, "\n%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
rules.clear(); rules.clear();
return false; return false;
} }