nits
This commit is contained in:
parent
f7078cab36
commit
ca0c837b6a
5 changed files with 50 additions and 47 deletions
6
Makefile
6
Makefile
|
@ -52,6 +52,7 @@ TEST_TARGETS = \
|
|||
tests/test-arg-parser \
|
||||
tests/test-autorelease \
|
||||
tests/test-backend-ops \
|
||||
tests/test-chat-handler \
|
||||
tests/test-chat-template \
|
||||
tests/test-double-float \
|
||||
tests/test-grammar-integration \
|
||||
|
@ -64,7 +65,6 @@ TEST_TARGETS = \
|
|||
tests/test-quantize-perf \
|
||||
tests/test-rope \
|
||||
tests/test-sampling \
|
||||
tests/test-tool-call \
|
||||
tests/test-tokenizer-0 \
|
||||
tests/test-tokenizer-1-bpe \
|
||||
tests/test-tokenizer-1-spm
|
||||
|
@ -984,8 +984,8 @@ OBJ_COMMON = \
|
|||
$(DIR_COMMON)/ngram-cache.o \
|
||||
$(DIR_COMMON)/sampling.o \
|
||||
$(DIR_COMMON)/speculative.o \
|
||||
$(DIR_COMMON)/chat-handler.o \
|
||||
$(DIR_COMMON)/build-info.o \
|
||||
$(DIR_COMMON)/tool-call.o \
|
||||
$(DIR_COMMON)/json-schema-to-grammar.o
|
||||
|
||||
OBJ_ALL = $(OBJ_GGML) $(OBJ_LLAMA) $(OBJ_COMMON)
|
||||
|
@ -1475,7 +1475,7 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \
|
|||
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
tests/test-tool-call: tests/test-tool-call.cpp \
|
||||
tests/test-chat-handler: tests/test-chat-handler.cpp \
|
||||
$(OBJ_ALL)
|
||||
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
|
|
@ -58,7 +58,7 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
|
|||
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
|
||||
* Aggregates the prefix, suffix and in-between text into the content.
|
||||
*/
|
||||
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
|
||||
static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool has_python) {
|
||||
std::smatch match;
|
||||
|
||||
common_chat_msg result;
|
||||
|
@ -102,7 +102,7 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
|
|||
|
||||
json arguments;
|
||||
if (!parse_json(it, end, arguments)) {
|
||||
if (name == "python" && std::regex_match("", close_regex)) {
|
||||
if (has_python && name == "python" && std::regex_match("", close_regex)) {
|
||||
std::string src(it, end);
|
||||
result.tool_calls.push_back({name, src, /* id= */ ""});
|
||||
break;
|
||||
|
@ -232,7 +232,7 @@ static void foreach_normalized_tool(const json & tools, const std::function<void
|
|||
}
|
||||
if (tool["type"] == "code_interpreter") {
|
||||
fn(python_tool);
|
||||
} else {
|
||||
} else if (tool["type"] == "function" && tool.contains("function")) {
|
||||
fn(tool);
|
||||
}
|
||||
}
|
||||
|
@ -400,16 +400,7 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te
|
|||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
|
||||
|
||||
for (const auto & tool : params.tools) {
|
||||
if (!tool.contains("type")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (tool["type"] == "code_interpreter") {
|
||||
builtin_tools.push_back("code_interpreter");
|
||||
has_python = true;
|
||||
} else if (tool["type"] == "function" && tool.contains("function")) {
|
||||
auto add_tool = [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
|
@ -435,14 +426,29 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te
|
|||
data.grammar_triggers.push_back({"{\"type\": \"function\", \"name\": \"" + name + "\"", /* .at_start = */ false});
|
||||
}
|
||||
}
|
||||
};
|
||||
for (const auto & tool : params.tools) {
|
||||
if (!tool.contains("type")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (tool["type"] == "code_interpreter") {
|
||||
builtin_tools.push_back("code_interpreter");
|
||||
has_python = true;
|
||||
} else if (tool["type"] == "function" && tool.contains("function")) {
|
||||
add_tool(tool);
|
||||
}
|
||||
}
|
||||
|
||||
if (has_python && uses_python_tag) {
|
||||
if (has_python) {
|
||||
if (uses_python_tag) {
|
||||
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
|
||||
if (params.tool_choice != "required") {
|
||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||
}
|
||||
} else {
|
||||
add_tool(python_tool);
|
||||
}
|
||||
}
|
||||
|
||||
if (params.tool_choice != "required" && eagerly_match_any_json) {
|
||||
|
@ -478,7 +484,7 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te
|
|||
}
|
||||
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
|
||||
static std::regex close_regex("\\}");
|
||||
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
|
||||
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, has_python && uses_python_tag);
|
||||
});
|
||||
return data;
|
||||
}
|
||||
|
@ -568,7 +574,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const
|
|||
data.parser = std::make_unique<monolithic_chat_parser>([params, has_python](const std::string & input) -> common_chat_msg {
|
||||
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
||||
static std::regex close_regex(R"($|(?=>>>))");
|
||||
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
|
||||
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, has_python);
|
||||
});
|
||||
return data;
|
||||
}
|
||||
|
@ -633,7 +639,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
|
|||
}
|
||||
static std::regex function_regex(R"(<function=(\w+)>)");
|
||||
static std::regex close_regex(R"(</function>)");
|
||||
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false);
|
||||
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false, /* has_python= */ false);
|
||||
});
|
||||
return data;
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ namespace minja {
|
|||
class chat_template {
|
||||
public:
|
||||
|
||||
// private:
|
||||
private:
|
||||
bool supports_tools_ = true;
|
||||
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
|
||||
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
|
||||
|
@ -147,7 +147,7 @@ class chat_template {
|
|||
static const auto python_tool = json::parse(R"({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ipython",
|
||||
"name": "python",
|
||||
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
|
@ -284,9 +284,6 @@ class chat_template {
|
|||
} else {
|
||||
actual_messages = messages;
|
||||
}
|
||||
// if (adjust_inputs) {
|
||||
// fprintf(stderr, "Messages: %s\n", actual_messages.dump(2).c_str());
|
||||
// }
|
||||
|
||||
auto context = minja::Context::make(json({
|
||||
{"messages", actual_messages},
|
||||
|
|
|
@ -1118,7 +1118,7 @@ curl http://localhost:8080/v1/chat/completions \
|
|||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ipython",
|
||||
"name": "python",
|
||||
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
|
@ -1155,7 +1155,7 @@ curl http://localhost:8080/v1/chat/completions \
|
|||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"name": "ipython",
|
||||
"name": "python",
|
||||
"arguments": "{\"code\":\" \\nprint(\\\"Hello, World!\\\")\"}"
|
||||
}
|
||||
],
|
||||
|
|
|
@ -560,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) {
|
|||
}
|
||||
}
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
|
||||
fprintf(stderr, "\n%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
|
||||
rules.clear();
|
||||
return false;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue