Fix Llama 3.1 (incl. constrained builtin tools e.g. <|python_tag|>foo.call(arg=vallue)
)
This commit is contained in:
parent
2d607f1a68
commit
ef9efc9ed3
3 changed files with 116 additions and 32 deletions
|
@ -207,7 +207,6 @@ static void foreach_function(const json & tools, const std::function<void(const
|
|||
}
|
||||
|
||||
static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
common_chat_data data;
|
||||
|
||||
auto tool_call_schemas = json::array();
|
||||
|
@ -318,7 +317,6 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
|
|||
}
|
||||
|
||||
static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
common_chat_data data;
|
||||
data.grammar_lazy = params.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
|
@ -358,25 +356,71 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
|
|||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
||||
data.format = "mistral nemo tool calls";
|
||||
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
|
||||
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
||||
});
|
||||
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
||||
});
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
// TODO: get from request body.
|
||||
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
||||
common_chat_data data;
|
||||
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
|
||||
if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
|
||||
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
|
||||
}
|
||||
const auto & parameters_properties = parameters.at("properties");
|
||||
const auto & parameters_required = parameters.at("required");
|
||||
for (const auto & prop : expected_properties) {
|
||||
if (!parameters_properties.contains(prop)) {
|
||||
throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
|
||||
}
|
||||
if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
|
||||
throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
|
||||
}
|
||||
}
|
||||
if (parameters_properties.size() != expected_properties.size()) {
|
||||
throw std::runtime_error("Parameters of tool " + name + " must only have these properties:" + string_join(expected_properties, ", "));
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
auto builtin_tools = json::array();
|
||||
common_chat_data data;
|
||||
data.grammar_lazy = params.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
|
||||
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
|
||||
if (name == "wolfram_alpha") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
|
||||
expect_tool_parameters(name, parameters, {"query"});
|
||||
} else if (name == "web_search" || name == "brave_search") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
|
||||
expect_tool_parameters(name, parameters, {"query"});
|
||||
} else if (name == "python" || name == "code_interpreter") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
|
||||
expect_tool_parameters(name, parameters, {"code"});
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> kvs;
|
||||
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
|
||||
}
|
||||
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
|
||||
builtin_tools.push_back(name);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
|
||||
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
|
||||
if (handle_builtin_tool(name, parameters)) {
|
||||
return;
|
||||
}
|
||||
builder.resolve_refs(parameters);
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
|
@ -388,30 +432,42 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
|
|||
" \"}\""));
|
||||
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
|
||||
});
|
||||
tool_rules.push_back(builder.add_rule("builtin-tool-call", "\"<|python_tag|>\" .*"));
|
||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||
if (!builtin_tools.empty()) {
|
||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||
}
|
||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
}, grammar_options);
|
||||
data.additional_stops.push_back("<|eom_id|>");
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {
|
||||
{"builtin_tools", builtin_tools},
|
||||
{"tools_in_user_message", false},
|
||||
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
||||
});
|
||||
data.format = "llama 3.1 tool calls";
|
||||
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
||||
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
|
||||
static std::regex close_regex("\\}");
|
||||
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\((.*)\)");
|
||||
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
|
||||
|
||||
std::smatch match;
|
||||
if (std::regex_match(input, match, builtin_call_regex)) {
|
||||
auto arguments = json::parse("[" + match[2].str() + "]");
|
||||
auto name = match[1].str();
|
||||
auto raw_args = match[2].str();
|
||||
|
||||
// TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
|
||||
auto it_eq = raw_args.find('=');
|
||||
auto arg_name = raw_args.substr(0, it_eq);
|
||||
auto arg_value_str = raw_args.substr(it_eq + 1);
|
||||
auto arg_value = json::parse(arg_value_str);
|
||||
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ match.prefix().str(),
|
||||
/* .tool_calls = */ {
|
||||
{
|
||||
/* .name = */ match[1],
|
||||
/* .arguments = */ arguments.dump(),
|
||||
/* .arguments = */ (json {
|
||||
{arg_name, arg_value},
|
||||
}).dump(),
|
||||
/* .id = */ "",
|
||||
},
|
||||
},
|
||||
|
@ -423,7 +479,6 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
|
|||
}
|
||||
|
||||
static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
common_chat_data data;
|
||||
|
||||
data.grammar_lazy = params.tool_choice != "required";
|
||||
|
@ -462,7 +517,6 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
|
|||
}
|
||||
|
||||
static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
common_chat_data data;
|
||||
data.grammar_lazy = params.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
|
@ -490,7 +544,6 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
|
|||
}
|
||||
|
||||
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
common_chat_data data;
|
||||
data.grammar_lazy = params.tool_choice != "required";
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
|
@ -529,7 +582,6 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
|
|||
}
|
||||
|
||||
static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||
common_chat_data data;
|
||||
|
@ -574,7 +626,6 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
|
|||
}
|
||||
|
||||
static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
|
||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||
common_chat_data data;
|
||||
|
@ -651,7 +702,6 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
|
|||
}
|
||||
|
||||
static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
common_chat_data data;
|
||||
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
||||
data.grammar_lazy = params.tool_choice != "required";
|
||||
|
@ -705,9 +755,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
|
|||
if (!parse_json(it, end, call)) {
|
||||
throw std::runtime_error("Failed to parse json tool call");
|
||||
}
|
||||
const auto & arguments = call["arguments"];
|
||||
result.tool_calls.push_back({
|
||||
call["name"],
|
||||
call["arguments"].dump(),
|
||||
arguments.dump(),
|
||||
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
/* id= */ "",
|
||||
});
|
||||
rit = {it, end, middle_pattern};
|
||||
|
@ -734,7 +786,6 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
|
|||
}
|
||||
|
||||
static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
common_chat_data data;
|
||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
||||
data.format = "content-only";
|
||||
|
|
|
@ -63,6 +63,8 @@ WEATHER_TOOL = {
|
|||
|
||||
|
||||
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
|
||||
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
|
||||
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
|
||||
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
|
||||
|
@ -78,8 +80,6 @@ WEATHER_TOOL = {
|
|||
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
|
||||
# TODO: fix these
|
||||
# ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
|
||||
# ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
|
||||
])
|
||||
def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
|
||||
n_predict = 512
|
||||
|
@ -118,6 +118,8 @@ def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argu
|
|||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("tool,argument_key,hf_repo,hf_file,template_override", [
|
||||
(TEST_TOOL, "success", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, "code", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||
(TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||
(TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
||||
|
@ -139,8 +141,6 @@ def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argu
|
|||
# TODO: fix these
|
||||
# (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
|
||||
# (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
|
||||
# (TEST_TOOL, "success", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||
# (PYTHON_TOOL, "code", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||
])
|
||||
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
||||
n_predict = 512
|
||||
|
@ -218,6 +218,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
|
|||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("hf_repo,hf_file,template_override", [
|
||||
("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||
("bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||
("bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
||||
("bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
|
||||
|
@ -229,7 +230,6 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
|
|||
("bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
# TODO: fix these
|
||||
# ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
|
||||
# ("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||
])
|
||||
def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
||||
global server
|
||||
|
@ -267,6 +267,7 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
|
|||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
|
||||
('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||
(None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
|
||||
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
|
@ -277,7 +278,6 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
|
|||
(None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
|
||||
# TODO: fix these
|
||||
# ('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||
# (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
|
||||
])
|
||||
def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
||||
|
|
|
@ -119,7 +119,25 @@ const auto python_tool = json::parse(R"({
|
|||
}
|
||||
}
|
||||
})");
|
||||
const auto code_interpreter_tool = json::parse(R"({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "code_interpreter",
|
||||
"description": "an ipython interpreter",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "Python code to execute."
|
||||
}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
}
|
||||
})");
|
||||
const json tools = {special_function_tool, python_tool};
|
||||
const json llama_3_1_tools = {special_function_tool, code_interpreter_tool};
|
||||
|
||||
// static void test_parsing() {
|
||||
// json request = {
|
||||
|
@ -427,6 +445,19 @@ static void test_grammars() {
|
|||
}},
|
||||
}}}
|
||||
};
|
||||
auto code_interpreter_tool_call_message = json {
|
||||
{"role", "assistant"},
|
||||
{"content", {}},
|
||||
{"tool_calls", json {{
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "code_interpreter"},
|
||||
{"arguments", {
|
||||
{"code", "print('hey')"},
|
||||
}},
|
||||
}},
|
||||
}}}
|
||||
};
|
||||
|
||||
|
||||
common_chat_params no_tools_params;
|
||||
|
@ -494,10 +525,12 @@ static void test_grammars() {
|
|||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||
test_template(tmpl, end_tokens, text_message, tools);
|
||||
// assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||
// test_template(tmpl, end_tokens, text_message, tools);
|
||||
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools);
|
||||
test_template(tmpl, end_tokens, python_tool_call_message, tools);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||
test_template(tmpl, end_tokens, python_tool_call_message, tools);
|
||||
test_template(tmpl, end_tokens, tool_call_message, llama_3_1_tools);
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue