Fix Llama 3.1 (incl. constrained builtin tools e.g. <|python_tag|>foo.call(arg=vallue)
)
This commit is contained in:
parent
2d607f1a68
commit
ef9efc9ed3
3 changed files with 116 additions and 32 deletions
|
@ -207,7 +207,6 @@ static void foreach_function(const json & tools, const std::function<void(const
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
fprintf(stderr, "[%s]\n", __func__);
|
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
|
|
||||||
auto tool_call_schemas = json::array();
|
auto tool_call_schemas = json::array();
|
||||||
|
@ -318,7 +317,6 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
fprintf(stderr, "[%s]\n", __func__);
|
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
data.grammar_lazy = params.tool_choice != "required";
|
data.grammar_lazy = params.tool_choice != "required";
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
|
@ -358,25 +356,71 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
||||||
data.format = "mistral nemo tool calls";
|
data.format = "mistral nemo tool calls";
|
||||||
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
|
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
|
||||||
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
||||||
});
|
});
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
|
||||||
fprintf(stderr, "[%s]\n", __func__);
|
if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
|
||||||
// TODO: get from request body.
|
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
|
||||||
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
}
|
||||||
common_chat_data data;
|
const auto & parameters_properties = parameters.at("properties");
|
||||||
|
const auto & parameters_required = parameters.at("required");
|
||||||
|
for (const auto & prop : expected_properties) {
|
||||||
|
if (!parameters_properties.contains(prop)) {
|
||||||
|
throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
|
||||||
|
}
|
||||||
|
if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
|
||||||
|
throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (parameters_properties.size() != expected_properties.size()) {
|
||||||
|
throw std::runtime_error("Parameters of tool " + name + " must only have these properties:" + string_join(expected_properties, ", "));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
auto builtin_tools = json::array();
|
||||||
|
common_chat_data data;
|
||||||
data.grammar_lazy = params.tool_choice != "required";
|
data.grammar_lazy = params.tool_choice != "required";
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::vector<std::string> tool_rules;
|
std::vector<std::string> tool_rules;
|
||||||
|
|
||||||
|
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
|
||||||
|
if (name == "wolfram_alpha") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
|
||||||
|
expect_tool_parameters(name, parameters, {"query"});
|
||||||
|
} else if (name == "web_search" || name == "brave_search") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
|
||||||
|
expect_tool_parameters(name, parameters, {"query"});
|
||||||
|
} else if (name == "python" || name == "code_interpreter") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
|
||||||
|
expect_tool_parameters(name, parameters, {"code"});
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> kvs;
|
||||||
|
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||||
|
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
|
||||||
|
}
|
||||||
|
|
||||||
|
tool_rules.push_back(
|
||||||
|
builder.add_rule(
|
||||||
|
name + "-call",
|
||||||
|
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
|
||||||
|
builtin_tools.push_back(name);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
foreach_function(params.tools, [&](const json & tool) {
|
foreach_function(params.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
std::string name = function["name"];
|
std::string name = function["name"];
|
||||||
auto parameters = function["parameters"];
|
auto parameters = function["parameters"];
|
||||||
|
|
||||||
|
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
|
||||||
|
if (handle_builtin_tool(name, parameters)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
builder.resolve_refs(parameters);
|
builder.resolve_refs(parameters);
|
||||||
tool_rules.push_back(
|
tool_rules.push_back(
|
||||||
builder.add_rule(
|
builder.add_rule(
|
||||||
|
@ -388,30 +432,42 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
|
||||||
" \"}\""));
|
" \"}\""));
|
||||||
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
|
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
|
||||||
});
|
});
|
||||||
tool_rules.push_back(builder.add_rule("builtin-tool-call", "\"<|python_tag|>\" .*"));
|
if (!builtin_tools.empty()) {
|
||||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||||
|
}
|
||||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
data.additional_stops.push_back("<|eom_id|>");
|
data.additional_stops.push_back("<|eom_id|>");
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {
|
||||||
{"builtin_tools", builtin_tools},
|
{"tools_in_user_message", false},
|
||||||
|
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
||||||
});
|
});
|
||||||
data.format = "llama 3.1 tool calls";
|
data.format = "llama 3.1 tool calls";
|
||||||
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
||||||
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
|
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
|
||||||
static std::regex close_regex("\\}");
|
static std::regex close_regex("\\}");
|
||||||
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\((.*)\)");
|
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
|
||||||
|
|
||||||
std::smatch match;
|
std::smatch match;
|
||||||
if (std::regex_match(input, match, builtin_call_regex)) {
|
if (std::regex_match(input, match, builtin_call_regex)) {
|
||||||
auto arguments = json::parse("[" + match[2].str() + "]");
|
auto name = match[1].str();
|
||||||
|
auto raw_args = match[2].str();
|
||||||
|
|
||||||
|
// TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
|
||||||
|
auto it_eq = raw_args.find('=');
|
||||||
|
auto arg_name = raw_args.substr(0, it_eq);
|
||||||
|
auto arg_value_str = raw_args.substr(it_eq + 1);
|
||||||
|
auto arg_value = json::parse(arg_value_str);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
/* .role = */ "assistant",
|
/* .role = */ "assistant",
|
||||||
/* .content = */ match.prefix().str(),
|
/* .content = */ match.prefix().str(),
|
||||||
/* .tool_calls = */ {
|
/* .tool_calls = */ {
|
||||||
{
|
{
|
||||||
/* .name = */ match[1],
|
/* .name = */ match[1],
|
||||||
/* .arguments = */ arguments.dump(),
|
/* .arguments = */ (json {
|
||||||
|
{arg_name, arg_value},
|
||||||
|
}).dump(),
|
||||||
/* .id = */ "",
|
/* .id = */ "",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -423,7 +479,6 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
fprintf(stderr, "[%s]\n", __func__);
|
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
|
|
||||||
data.grammar_lazy = params.tool_choice != "required";
|
data.grammar_lazy = params.tool_choice != "required";
|
||||||
|
@ -462,7 +517,6 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
fprintf(stderr, "[%s]\n", __func__);
|
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
data.grammar_lazy = params.tool_choice != "required";
|
data.grammar_lazy = params.tool_choice != "required";
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
|
@ -490,7 +544,6 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
fprintf(stderr, "[%s]\n", __func__);
|
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
data.grammar_lazy = params.tool_choice != "required";
|
data.grammar_lazy = params.tool_choice != "required";
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
|
@ -529,7 +582,6 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
fprintf(stderr, "[%s]\n", __func__);
|
|
||||||
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||||
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
|
@ -574,7 +626,6 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
fprintf(stderr, "[%s]\n", __func__);
|
|
||||||
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
|
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
|
||||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
|
@ -651,7 +702,6 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
fprintf(stderr, "[%s]\n", __func__);
|
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
||||||
data.grammar_lazy = params.tool_choice != "required";
|
data.grammar_lazy = params.tool_choice != "required";
|
||||||
|
@ -705,9 +755,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
|
||||||
if (!parse_json(it, end, call)) {
|
if (!parse_json(it, end, call)) {
|
||||||
throw std::runtime_error("Failed to parse json tool call");
|
throw std::runtime_error("Failed to parse json tool call");
|
||||||
}
|
}
|
||||||
|
const auto & arguments = call["arguments"];
|
||||||
result.tool_calls.push_back({
|
result.tool_calls.push_back({
|
||||||
call["name"],
|
call["name"],
|
||||||
call["arguments"].dump(),
|
arguments.dump(),
|
||||||
|
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||||
/* id= */ "",
|
/* id= */ "",
|
||||||
});
|
});
|
||||||
rit = {it, end, middle_pattern};
|
rit = {it, end, middle_pattern};
|
||||||
|
@ -734,7 +786,6 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
fprintf(stderr, "[%s]\n", __func__);
|
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
|
||||||
data.format = "content-only";
|
data.format = "content-only";
|
||||||
|
|
|
@ -63,6 +63,8 @@ WEATHER_TOOL = {
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
||||||
|
("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
|
||||||
|
("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
|
||||||
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
|
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
|
||||||
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
|
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
|
||||||
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
|
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
|
||||||
|
@ -78,8 +80,6 @@ WEATHER_TOOL = {
|
||||||
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
|
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
|
||||||
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
|
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
|
||||||
# TODO: fix these
|
# TODO: fix these
|
||||||
# ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
|
|
||||||
# ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
|
|
||||||
])
|
])
|
||||||
def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
|
def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
|
||||||
n_predict = 512
|
n_predict = 512
|
||||||
|
@ -118,6 +118,8 @@ def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argu
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
@pytest.mark.parametrize("tool,argument_key,hf_repo,hf_file,template_override", [
|
@pytest.mark.parametrize("tool,argument_key,hf_repo,hf_file,template_override", [
|
||||||
|
(TEST_TOOL, "success", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||||
|
(PYTHON_TOOL, "code", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||||
(TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
(TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||||
(PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
(PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||||
(TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
(TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
||||||
|
@ -139,8 +141,6 @@ def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argu
|
||||||
# TODO: fix these
|
# TODO: fix these
|
||||||
# (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
|
# (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
|
||||||
# (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
|
# (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
|
||||||
# (TEST_TOOL, "success", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
|
||||||
# (PYTHON_TOOL, "code", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
|
||||||
])
|
])
|
||||||
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
||||||
n_predict = 512
|
n_predict = 512
|
||||||
|
@ -218,6 +218,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
@pytest.mark.parametrize("hf_repo,hf_file,template_override", [
|
@pytest.mark.parametrize("hf_repo,hf_file,template_override", [
|
||||||
|
("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||||
("bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
("bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||||
("bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
("bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
||||||
("bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
|
("bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
|
||||||
|
@ -229,7 +230,6 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
|
||||||
("bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
("bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||||
# TODO: fix these
|
# TODO: fix these
|
||||||
# ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
|
# ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
|
||||||
# ("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
|
||||||
])
|
])
|
||||||
def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
||||||
global server
|
global server
|
||||||
|
@ -267,6 +267,7 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
|
@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
|
||||||
|
('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||||
(None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
|
(None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
|
||||||
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||||
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||||
|
@ -277,7 +278,6 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
|
||||||
(None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
(None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||||
(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
|
(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
|
||||||
# TODO: fix these
|
# TODO: fix these
|
||||||
# ('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
|
||||||
# (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
|
# (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
|
||||||
])
|
])
|
||||||
def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
||||||
|
|
|
@ -119,7 +119,25 @@ const auto python_tool = json::parse(R"({
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})");
|
})");
|
||||||
|
const auto code_interpreter_tool = json::parse(R"({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "code_interpreter",
|
||||||
|
"description": "an ipython interpreter",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"code": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Python code to execute."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["code"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})");
|
||||||
const json tools = {special_function_tool, python_tool};
|
const json tools = {special_function_tool, python_tool};
|
||||||
|
const json llama_3_1_tools = {special_function_tool, code_interpreter_tool};
|
||||||
|
|
||||||
// static void test_parsing() {
|
// static void test_parsing() {
|
||||||
// json request = {
|
// json request = {
|
||||||
|
@ -427,6 +445,19 @@ static void test_grammars() {
|
||||||
}},
|
}},
|
||||||
}}}
|
}}}
|
||||||
};
|
};
|
||||||
|
auto code_interpreter_tool_call_message = json {
|
||||||
|
{"role", "assistant"},
|
||||||
|
{"content", {}},
|
||||||
|
{"tool_calls", json {{
|
||||||
|
{"type", "function"},
|
||||||
|
{"function", {
|
||||||
|
{"name", "code_interpreter"},
|
||||||
|
{"arguments", {
|
||||||
|
{"code", "print('hey')"},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
}}}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
common_chat_params no_tools_params;
|
common_chat_params no_tools_params;
|
||||||
|
@ -494,10 +525,12 @@ static void test_grammars() {
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
// assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||||
test_template(tmpl, end_tokens, text_message, tools);
|
// test_template(tmpl, end_tokens, text_message, tools);
|
||||||
|
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools);
|
||||||
|
test_template(tmpl, end_tokens, python_tool_call_message, tools);
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools);
|
test_template(tmpl, end_tokens, tool_call_message, tools);
|
||||||
test_template(tmpl, end_tokens, python_tool_call_message, tools);
|
test_template(tmpl, end_tokens, tool_call_message, llama_3_1_tools);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue