Fix Llama 3.1 (incl. constrained builtin tools e.g. <|python_tag|>foo.call(arg=value))

ochafik 2025-01-28 01:04:06 +00:00
parent 2d607f1a68
commit ef9efc9ed3
3 changed files with 116 additions and 32 deletions
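For context: with this change, builtin tool calls are constrained to the form <|python_tag|>NAME.call(ARG=VALUE), and the parser extracts the tool name and raw argument text with the updated builtin_call_regex. A minimal standalone sketch of that parse (illustrative only, not code from this commit):

#include <iostream>
#include <regex>
#include <string>

int main() {
    // Example model output in the constrained builtin tool-call format.
    const std::string input = "<|python_tag|>brave_search.call(query=\"Who are you?\")";
    // Same shape as the updated builtin_call_regex in this commit.
    static const std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
    std::smatch match;
    if (std::regex_match(input, match, builtin_call_regex)) {
        std::cout << "name: " << match[1] << "\n";  // brave_search
        std::cout << "args: " << match[2] << "\n";  // query="Who are you?"
    }
}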

View file

@@ -207,7 +207,6 @@ static void foreach_function(const json & tools, const std::function<void(const
}
static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
auto tool_call_schemas = json::array();
@@ -318,7 +317,6 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
}
static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -358,25 +356,71 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "mistral nemo tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
});
return data;
}
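(Aside: the parser above assumes the model emits the literal "[TOOL_CALLS]" prefix followed by a JSON array of calls. A self-contained sketch of that assumed behavior, not the actual parse_prefixed_json_tool_call_array implementation:)

#include <nlohmann/json.hpp>
#include <iostream>
#include <string>
using json = nlohmann::json;

int main() {
    const std::string input = "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]";
    const std::string prefix = "[TOOL_CALLS]";
    if (input.rfind(prefix, 0) == 0) {  // input starts with the prefix
        auto calls = json::parse(input.substr(prefix.size()));
        std::cout << calls[0]["name"] << "\n";  // "special_function"
    }
}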
static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
// TODO: get from request body.
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
common_chat_data data;
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
}
const auto & parameters_properties = parameters.at("properties");
const auto & parameters_required = parameters.at("required");
for (const auto & prop : expected_properties) {
if (!parameters_properties.contains(prop)) {
throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
}
if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
}
}
if (parameters_properties.size() != expected_properties.size()) {
throw std::runtime_error("Parameters of tool " + name + " must only have these properties:" + string_join(expected_properties, ", "));
}
}
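For reference, a brave_search parameters schema that passes the checks above would look like this (a sketch using the nlohmann::json alias already used in this file; the description text is illustrative):

// Accepted by expect_tool_parameters("brave_search", brave_search_parameters, {"query"}).
static const json brave_search_parameters = json::parse(R"({
    "type": "object",
    "properties": {
        "query": { "type": "string", "description": "The query to search for." }
    },
    "required": ["query"]
})");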
static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
auto builtin_tools = json::array();
common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
if (name == "wolfram_alpha") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
expect_tool_parameters(name, parameters, {"query"});
} else if (name == "web_search" || name == "brave_search") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
expect_tool_parameters(name, parameters, {"query"});
} else if (name == "python" || name == "code_interpreter") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
expect_tool_parameters(name, parameters, {"code"});
} else {
return false;
}
std::vector<std::string> kvs;
for (const auto & [key, value] : parameters.at("properties").items()) {
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
}
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
builtin_tools.push_back(name);
return true;
};
foreach_function(params.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
if (handle_builtin_tool(name, parameters)) {
return;
}
builder.resolve_refs(parameters);
tool_rules.push_back(
builder.add_rule(
@@ -388,30 +432,42 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
" \"}\""));
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
});
tool_rules.push_back(builder.add_rule("builtin-tool-call", "\"<|python_tag|>\" .*"));
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
if (!builtin_tools.empty()) {
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
}
builder.add_rule("root", string_join(tool_rules, " | "));
}, grammar_options);
data.additional_stops.push_back("<|eom_id|>");
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {
{"builtin_tools", builtin_tools},
{"tools_in_user_message", false},
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
});
data.format = "llama 3.1 tool calls";
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
static std::regex close_regex("\\}");
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\((.*)\)");
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
std::smatch match;
if (std::regex_match(input, match, builtin_call_regex)) {
auto arguments = json::parse("[" + match[2].str() + "]");
auto name = match[1].str();
auto raw_args = match[2].str();
// TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
auto it_eq = raw_args.find('=');
auto arg_name = raw_args.substr(0, it_eq);
auto arg_value_str = raw_args.substr(it_eq + 1);
auto arg_value = json::parse(arg_value_str);
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
/* .name = */ match[1],
/* .arguments = */ arguments.dump(),
/* .arguments = */ (json {
{arg_name, arg_value},
}).dump(),
/* .id = */ "",
},
},
@@ -423,7 +479,6 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
}
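In isolation, the single-argument split used above takes the text before the first '=' as the key and JSON-parses the remainder as the value. A standalone sketch (assumes nlohmann::json):

#include <nlohmann/json.hpp>
#include <cassert>
#include <string>
using json = nlohmann::json;

int main() {
    const std::string raw_args = "query=\"hello\"";
    auto it_eq = raw_args.find('=');
    auto arg_name  = raw_args.substr(0, it_eq);                 // "query"
    auto arg_value = json::parse(raw_args.substr(it_eq + 1));   // the JSON string "hello"
    json arguments {{arg_name, arg_value}};
    assert(arguments.dump() == "{\"query\":\"hello\"}");
}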
static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
@@ -462,7 +517,6 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
}
static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -490,7 +544,6 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
}
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
data.grammar_lazy = params.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -529,7 +582,6 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
}
static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
common_chat_data data;
@@ -574,7 +626,6 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
}
static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
common_chat_data data;
@@ -651,7 +702,6 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
}
static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
data.grammar_lazy = params.tool_choice != "required";
@@ -705,9 +755,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
if (!parse_json(it, end, call)) {
throw std::runtime_error("Failed to parse json tool call");
}
const auto & arguments = call["arguments"];
result.tool_calls.push_back({
call["name"],
call["arguments"].dump(),
arguments.dump(),
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
/* id= */ "",
});
rit = {it, end, middle_pattern};
@@ -734,7 +786,6 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
}
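A note on the grammar_triggers used by the tool-call initializers above: when grammar_lazy is set, the grammar only starts constraining sampling once one of the trigger strings appears in the output (optionally anchored at the start). A rough sketch of that idea, with hypothetical helper names rather than this repo's API:

#include <string>
#include <vector>

struct trigger { std::string word; bool at_start; };

// Hypothetical: returns true once any trigger has fired, after which
// the lazy grammar would begin constraining subsequent tokens.
static bool grammar_should_activate(const std::string & output, const std::vector<trigger> & triggers) {
    for (const auto & t : triggers) {
        auto pos = output.find(t.word);
        if (pos != std::string::npos && (!t.at_start || pos == 0)) {
            return true;
        }
    }
    return false;
}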
static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
common_chat_data data;
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
data.format = "content-only";

View file

@@ -63,6 +63,8 @@ WEATHER_TOOL = {
@pytest.mark.parametrize("template_name,tool,argument_key", [
("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
@@ -78,8 +80,6 @@ WEATHER_TOOL = {
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
# TODO: fix these
# ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
# ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
])
def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
n_predict = 512
@@ -118,6 +118,8 @@ def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argu
@pytest.mark.slow
@pytest.mark.parametrize("tool,argument_key,hf_repo,hf_file,template_override", [
(TEST_TOOL, "success", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
(PYTHON_TOOL, "code", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
(TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
(PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
(TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
@@ -139,8 +141,6 @@ def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argu
# TODO: fix these
# (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
# (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
# (TEST_TOOL, "success", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
# (PYTHON_TOOL, "code", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
])
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
n_predict = 512
@@ -218,6 +218,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
@pytest.mark.slow
@pytest.mark.parametrize("hf_repo,hf_file,template_override", [
("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
("bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
("bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
("bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
@@ -229,7 +230,6 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
("bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
# TODO: fix these
# ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
# ("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
])
def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
global server
@@ -267,6 +267,7 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
@pytest.mark.slow
@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
(None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
@@ -277,7 +278,6 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
(None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
# TODO: fix these
# ('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
# (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
])
def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):

View file

@@ -119,7 +119,25 @@ const auto python_tool = json::parse(R"({
}
}
})");
const auto code_interpreter_tool = json::parse(R"({
"type": "function",
"function": {
"name": "code_interpreter",
"description": "an ipython interpreter",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "Python code to execute."
}
},
"required": ["code"]
}
}
})");
const json tools = {special_function_tool, python_tool};
const json llama_3_1_tools = {special_function_tool, code_interpreter_tool};
// static void test_parsing() {
// json request = {
@@ -427,6 +445,19 @@ static void test_grammars() {
}},
}}}
};
auto code_interpreter_tool_call_message = json {
{"role", "assistant"},
{"content", {}},
{"tool_calls", json {{
{"type", "function"},
{"function", {
{"name", "code_interpreter"},
{"arguments", {
{"code", "print('hey')"},
}},
}},
}}}
};
common_chat_params no_tools_params;
@@ -494,10 +525,12 @@ static void test_grammars() {
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
test_template(tmpl, end_tokens, text_message, tools);
// assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
// test_template(tmpl, end_tokens, text_message, tools);
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools);
test_template(tmpl, end_tokens, python_tool_call_message, tools);
test_template(tmpl, end_tokens, tool_call_message, tools);
test_template(tmpl, end_tokens, python_tool_call_message, tools);
test_template(tmpl, end_tokens, tool_call_message, llama_3_1_tools);
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");