This commit is contained in:
ochafik 2025-01-27 01:08:29 +00:00
parent f7078cab36
commit ca0c837b6a
5 changed files with 50 additions and 47 deletions

View file

@ -52,6 +52,7 @@ TEST_TARGETS = \
tests/test-arg-parser \
tests/test-autorelease \
tests/test-backend-ops \
tests/test-chat-handler \
tests/test-chat-template \
tests/test-double-float \
tests/test-grammar-integration \
@ -64,7 +65,6 @@ TEST_TARGETS = \
tests/test-quantize-perf \
tests/test-rope \
tests/test-sampling \
tests/test-tool-call \
tests/test-tokenizer-0 \
tests/test-tokenizer-1-bpe \
tests/test-tokenizer-1-spm
@ -984,8 +984,8 @@ OBJ_COMMON = \
$(DIR_COMMON)/ngram-cache.o \
$(DIR_COMMON)/sampling.o \
$(DIR_COMMON)/speculative.o \
$(DIR_COMMON)/chat-handler.o \
$(DIR_COMMON)/build-info.o \
$(DIR_COMMON)/tool-call.o \
$(DIR_COMMON)/json-schema-to-grammar.o
OBJ_ALL = $(OBJ_GGML) $(OBJ_LLAMA) $(OBJ_COMMON)
@ -1475,7 +1475,7 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
tests/test-tool-call: tests/test-tool-call.cpp \
tests/test-chat-handler: tests/test-chat-handler.cpp \
$(OBJ_ALL)
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

View file

@ -58,7 +58,7 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
* Takes a prefix regex that must have exactly 1 group to capture the function name, and a closing suffix regex; expects JSON parameters in between.
* Aggregates the prefix, suffix and in-between text into the content.
*/
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool has_python) {
std::smatch match;
common_chat_msg result;
@ -102,7 +102,7 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
json arguments;
if (!parse_json(it, end, arguments)) {
if (name == "python" && std::regex_match("", close_regex)) {
if (has_python && name == "python" && std::regex_match("", close_regex)) {
std::string src(it, end);
result.tool_calls.push_back({name, src, /* id= */ ""});
break;
@ -232,7 +232,7 @@ static void foreach_normalized_tool(const json & tools, const std::function<void
}
if (tool["type"] == "code_interpreter") {
fn(python_tool);
} else {
} else if (tool["type"] == "function" && tool.contains("function")) {
fn(tool);
}
}
@ -400,7 +400,33 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
// Registers a single function-style tool (an object with a "function" entry holding
// "name" and "parameters") with the grammar under construction.
// - If the python tag is in use and the tool is named "python"/"ipython" (or is listed in
//   builtin_tools), it only sets has_python; no per-tool rule is emitted here.
// - Otherwise it adds a "<name>-call" rule to tool_rules and, unless tool_choice is
//   "required" or eagerly_match_any_json is set, pushes the trigger strings for that tool
//   onto data.grammar_triggers.
auto add_tool = [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
// Taken by value (a copy) before being handed to resolve_refs.
auto parameters = function["parameters"];
builder.resolve_refs(parameters);
if (uses_python_tag && (name == "python" || name == "ipython" || builtin_tools.contains(name))) {
has_python = true;
} else {
//"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
// Rule shape: optional newline, "{", then either a literal `"type": "function", ` prefix
// or the `space` rule, then the literal `"name": "<name>", "parameters": ` followed by
// the schema-derived arguments rule and a closing "}".
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) +
" \"}\""));
if (params.tool_choice != "required" && !eagerly_match_any_json) {
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ false});
// Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B.
// Note that C++11's regex doesn't support partial matches, otherwise it would make
// sense to add support for trigger regexes to the antiprompt mechanism.
// NOTE(review): the second and third triggers below are identical as shown here;
// presumably they were meant to differ in indentation width — confirm against the real file.
data.grammar_triggers.push_back({"{\n\t\"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\"type\": \"function\", \"name\": \"" + name + "\"", /* .at_start = */ false});
}
}
};
for (const auto & tool : params.tools) {
if (!tool.contains("type")) {
continue;
@ -410,38 +436,18 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te
builtin_tools.push_back("code_interpreter");
has_python = true;
} else if (tool["type"] == "function" && tool.contains("function")) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
builder.resolve_refs(parameters);
if (uses_python_tag && (name == "python" || name == "ipython" || builtin_tools.contains(name))) {
has_python = true;
} else {
//"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) +
" \"}\""));
if (params.tool_choice != "required" && !eagerly_match_any_json) {
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ false});
// Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B.
// Note that c++11's regex doesn't support partial matches, otherwise it would make
// sense to add support for trigger regexes to the antiprompt mechanism.
data.grammar_triggers.push_back({"{\n\t\"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false});
data.grammar_triggers.push_back({"{\"type\": \"function\", \"name\": \"" + name + "\"", /* .at_start = */ false});
}
}
add_tool(tool);
}
}
if (has_python && uses_python_tag) {
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
if (has_python) {
if (uses_python_tag) {
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
}
} else {
add_tool(python_tool);
}
}
@ -478,7 +484,7 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te
}
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
static std::regex close_regex("\\}");
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, has_python && uses_python_tag);
});
return data;
}
@ -568,7 +574,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const
data.parser = std::make_unique<monolithic_chat_parser>([params, has_python](const std::string & input) -> common_chat_msg {
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
static std::regex close_regex(R"($|(?=>>>))");
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, has_python);
});
return data;
}
@ -633,7 +639,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
}
static std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)");
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false);
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false, /* has_python= */ false);
});
return data;
}

View file

@ -20,7 +20,7 @@ namespace minja {
class chat_template {
public:
// private:
private:
bool supports_tools_ = true;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
@ -147,7 +147,7 @@ class chat_template {
static const auto python_tool = json::parse(R"({
"type": "function",
"function": {
"name": "ipython",
"name": "python",
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"parameters": {
"type": "object",
@ -284,9 +284,6 @@ class chat_template {
} else {
actual_messages = messages;
}
// if (adjust_inputs) {
// fprintf(stderr, "Messages: %s\n", actual_messages.dump(2).c_str());
// }
auto context = minja::Context::make(json({
{"messages", actual_messages},

View file

@ -1118,7 +1118,7 @@ curl http://localhost:8080/v1/chat/completions \
{
"type": "function",
"function": {
"name": "ipython",
"name": "python",
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"parameters": {
"type": "object",
@ -1155,7 +1155,7 @@ curl http://localhost:8080/v1/chat/completions \
"content": null,
"tool_calls": [
{
"name": "ipython",
"name": "python",
"arguments": "{\"code\":\" \\nprint(\\\"Hello, World!\\\")\"}"
}
],

View file

@ -560,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) {
}
}
} catch (const std::exception & err) {
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
fprintf(stderr, "\n%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
rules.clear();
return false;
}