Unify llama 3.x chat handling again (allow {"type": "function", "name": ... prefix)

commit ba27e98582
parent 7b5e0803c8

5 changed files with 79 additions and 108 deletions
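Llama 3.x models emit JSON tool calls either as `{"name": ..., "parameters": ...}` or with a leading type field, `{"type": "function", "name": ..., "parameters": ...}`. This commit makes both the grammar and the output parser accept either shape, and merges the separate llama 3.1 (`<|python_tag|>` builtin tools) and llama 3.2 code paths into a single `common_chat_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools)`. It also demotes some server logs from info to debug, adds trigger tracing to `llama_grammar_accept_impl`, and refreshes the server tool-call tests.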
```diff
@@ -344,7 +344,7 @@ static void expect_tool_parameters(const std::string & name, const json & parameters
     }
 }
 
-static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
+static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool allow_python_tag_builtin_tools) {
     auto builtin_tools = json::array();
     common_chat_data data;
     data.grammar_lazy = params.tool_choice != "required";
@@ -379,24 +379,31 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
            return true;
        };
 
+       auto has_function = false;
        foreach_function(params.tools, [&](const json & tool) {
            const auto & function = tool["function"];
            std::string name = function["name"];
            auto parameters = function["parameters"];
 
            // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
-           if (handle_builtin_tool(name, parameters)) {
+           if (allow_python_tag_builtin_tools && handle_builtin_tool(name, parameters)) {
                return;
            }
            builder.resolve_refs(parameters);
            tool_rules.push_back(
                builder.add_rule(
                    name + "-call",
-                   "\"{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
+                   "\"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) "
+                   "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
                    builder.add_schema(name + "-args", parameters) +
                    " \"}\""));
            data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
+           has_function = true;
        });
+       if (has_function) {
+           data.grammar_triggers.push_back({"{\"name\":", /* .at_start = */ true});
+           data.grammar_triggers.push_back({"{\"type\": \"function\"", /* .at_start = */ true});
+       }
        if (!builtin_tools.empty()) {
            data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
        }
```
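The interesting change is in the `name + "-call"` rule: after the opening `{`, the grammar now accepts either the literal `"type": "function", ` prefix or plain whitespace (`space` being the builder's whitespace rule) before `"name"`. A minimal sketch, not part of the commit, that assembles the same production string so the escaping is easier to read (`get_weather` is a hypothetical tool name; `get_weather-args` stands in for the rule name that `builder.add_schema` would return):

```cpp
// Sketch only: mirrors the string concatenation in the diff above.
#include <iostream>
#include <string>

int main() {
    std::string name = "get_weather";
    std::string rule =
        "\"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) "
        "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
        name + "-args" +
        " \"}\"";
    std::cout << name << "-call ::= " << rule << "\n";
    // Prints (wrapped here for readability):
    //   get_weather-call ::= "{" ( "\"type\": \"function\", " | space )
    //       "\"name\": \"get_weather\", \"parameters\": " get_weather-args "}"
}
```

The alternation after `"{"` is exactly what lets constrained sampling produce either tool-call shape.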
```diff
@@ -407,79 +414,44 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
        {"tools_in_user_message", false},
        {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
    });
-   data.format = "llama 3.1 tool calls";
-   data.parser = [params](const std::string & input) -> common_chat_msg {
-       static std::regex function_regex("\\{\"name\": \"([^\"]+)\", \"parameters\": ");
+   data.format = std::string("llama 3.x tool calls") + (allow_python_tag_builtin_tools ? " (w/ builtin tools)" : "");
+   data.parser = [params, builtin_tools, allow_python_tag_builtin_tools](const std::string & input) -> common_chat_msg {
+       static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
        static std::regex close_regex("\\}");
        static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
 
-       std::smatch match;
-       if (std::regex_match(input, match, builtin_call_regex)) {
-           auto name = match[1].str();
-           auto raw_args = match[2].str();
+       if (allow_python_tag_builtin_tools && !builtin_tools.empty()) {
+           std::smatch match;
+           if (std::regex_match(input, match, builtin_call_regex)) {
+               auto name = match[1].str();
+               auto raw_args = match[2].str();
 
                // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
                auto it_eq = raw_args.find('=');
                auto arg_name = raw_args.substr(0, it_eq);
                auto arg_value_str = raw_args.substr(it_eq + 1);
                auto arg_value = json::parse(arg_value_str);
 
                return {
                    /* .role = */ "assistant",
                    /* .content = */ match.prefix().str(),
                    /* .tool_calls = */ {
                        {
                            /* .name = */ match[1],
                            /* .arguments = */ (json {
                                {arg_name, arg_value},
                            }).dump(),
                            /* .id = */ "",
+                       },
                    },
-           },
-       };
+               };
+           }
        }
        return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
    };
    return data;
 }
 
-static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
-   common_chat_data data;
-
-   data.grammar_lazy = params.tool_choice != "required";
-   data.grammar = build_grammar([&](const common_grammar_builder & builder) {
-       std::vector<std::string> tool_rules;
-
-       foreach_function(params.tools, [&](const json & tool) {
-           const auto & function = tool["function"];
-           std::string name = function["name"];
-           auto parameters = function["parameters"];
-           builder.resolve_refs(parameters);
-           tool_rules.push_back(
-               builder.add_rule(
-                   name + "-call",
-                   "\"{\" "
-                   // " ( \"\\\"type\\\": \\\"function\\\", \" | space ) "
-                   "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
-                   builder.add_schema(name + "-args", parameters) +
-                   " \"}\""));
-           data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
-       });
-
-       builder.add_rule("root", string_join(tool_rules, " | "));
-   }, grammar_options);
-   data.additional_stops.push_back("<|eom_id|>");
-   data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {});
-   data.format = "llama 3.2 tool calls";
-   data.parser = [params](const std::string & input) {
-       static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
-       static std::regex close_regex("\\}");
-       auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
-       return res;
-   };
-   return data;
-}
-
 static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
    common_chat_data data;
    data.grammar_lazy = params.tool_choice != "required";
```
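On the parsing side, the old `function_regex` only matched the bare `{"name": "...", "parameters": ` prefix; the new one optionally consumes a `"type": "function",` field and tolerates whitespace. A self-contained sanity check of both shapes against the new pattern (not part of the commit; `get_weather` is a made-up tool name, and search semantics are assumed as in `parse_json_tool_calls`):

```cpp
// Standalone check: both tool-call prefixes should capture the same name.
#include <cassert>
#include <regex>
#include <string>

int main() {
    static const std::regex function_regex(
        "\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)"
        "\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");

    const std::string plain = "{\"name\": \"get_weather\", \"parameters\": {\"location\": \"Istanbul\"}}";
    const std::string typed = "{\"type\": \"function\", \"name\": \"get_weather\", \"parameters\": {\"location\": \"Istanbul\"}}";

    std::smatch m;
    assert(std::regex_search(plain, m, function_regex) && m[1] == "get_weather");
    assert(std::regex_search(typed, m, function_regex) && m[1] == "get_weather");
    return 0;
}
```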
```diff
@@ -559,8 +531,8 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
    // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
    common_chat_data data;
 
-   data.grammar_lazy = params.tool_choice != "required";
    if (!params.tools.is_null() && !params.tools.empty()) {
+       data.grammar_lazy = params.tool_choice != "required";
        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
            std::vector<std::string> first_tool_rules;
            std::vector<std::string> subsequent_tool_rules;
@@ -806,13 +778,8 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
        return common_chat_init_functionary_v3_1_llama_3_1_tool_call(tmpl, params);
    }
    if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
-       auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos;
-       if (uses_python_tag) {
-           return common_chat_init_llama_3_1_python_tag_tool_calls(tmpl, params);
-       } else {
-           return common_chat_init_llama_3_2_tool_calls(tmpl, params);
-       }
+       auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
+       return common_chat_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
    }
    if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
        return common_chat_init_deepseek_r1_tool_call(tmpl, params);
```
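The builtin-tools branch, now gated on `allow_python_tag_builtin_tools`, handles `<|python_tag|>` calls of the form `name.call(arg=value)`. An illustrative standalone run of `builtin_call_regex` from the parser above (the tool name and query are made up; the single-kwarg split matches the TODO note in the diff):

```cpp
// Sketch only: exercises builtin_call_regex on a hypothetical builtin call.
#include <cassert>
#include <regex>
#include <string>

int main() {
    static const std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");

    const std::string input = "<|python_tag|>brave_search.call(query=\"weather Istanbul\")";
    std::smatch m;
    assert(std::regex_match(input, m, builtin_call_regex));

    const std::string name     = m[1];  // brave_search
    const std::string raw_args = m[2];  // query="weather Istanbul"
    const auto it_eq = raw_args.find('=');
    assert(name == "brave_search");
    assert(raw_args.substr(0, it_eq) == "query");                  // argument name
    assert(raw_args.substr(it_eq + 1) == "\"weather Istanbul\"");  // JSON-parseable value
    return 0;
}
```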
```diff
@@ -3800,6 +3800,8 @@ int main(int argc, char ** argv) {
                /* .grammar = */ json_value(data, "grammar", std::string("")),
            });
            LOG_INF("Chat format: %s\n", chat_data.format.c_str());
+           LOG_DBG("Prompt: %s\n", chat_data.prompt.get<std::string>().c_str());
+           LOG_DBG("Grammar: %s\n", chat_data.grammar.c_str());
            if (data.contains("grammar")) {
                if (!chat_data.grammar.empty()) {
                    throw std::runtime_error("Cannot provide grammar and tools");
@@ -3841,11 +3843,11 @@ int main(int argc, char ** argv) {
            for (const auto & trigger : chat_data.grammar_triggers) {
                auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
                if (ids.size() == 1) {
-                   LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
+                   LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
                    task.params.sampling.grammar_trigger_tokens.push_back(ids[0]);
                    continue;
                }
-               LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
+               LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
                task.params.sampling.grammar_trigger_words.push_back(trigger);
            }
            task.params.antiprompt = chat_data.additional_stops;
```
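The loop above is why there are two trigger mechanisms: a special token such as `<|python_tag|>` tokenizes to a single id (note the `parse_special = true`) and can be matched as the token is sampled, while a textual prefix like `{"name":` spans several tokens and has to be searched for in the decoded output. A toy model of the dispatch, with a made-up tokenizer and token id:

```cpp
// Toy dispatch, mirroring the loop above without the server types.
#include <cassert>
#include <string>
#include <vector>

// Hypothetical tokenizer: one id for the known special token, otherwise one
// id per byte (so multi-character text never tokenizes to a single id here).
static std::vector<int> toy_tokenize(const std::string & word) {
    if (word == "<|python_tag|>") {
        return { 128010 };  // illustrative id, not necessarily the real one
    }
    std::vector<int> ids;
    for (unsigned char c : word) {
        ids.push_back(c);
    }
    return ids;
}

int main() {
    std::vector<int>         trigger_tokens;
    std::vector<std::string> trigger_words;
    for (const std::string & trigger : { std::string("<|python_tag|>"), std::string("{\"name\":") }) {
        auto ids = toy_tokenize(trigger);
        if (ids.size() == 1) {
            trigger_tokens.push_back(ids[0]);  // cheap per-token match
        } else {
            trigger_words.push_back(trigger);  // substring search on decoded text
        }
    }
    assert(trigger_tokens.size() == 1 && trigger_words.size() == 1);
    return 0;
}
```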
```diff
@@ -4021,6 +4023,7 @@ int main(int argc, char ** argv) {
    };
 
    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
+       LOG_DBG("request: %s\n", req.body.c_str());
        if (ctx_server.params_base.embedding) {
            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return;
```
```diff
@@ -58,7 +58,7 @@ WEATHER_TOOL = {
            "required":["location"]
        }
    }
-}# TODO: fix this crash
+}
 
 
 def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
@@ -132,8 +132,8 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
 
 @pytest.mark.slow
 @pytest.mark.parametrize("tool,argument_key,hf_repo,hf_file,template_override", [
-    (TEST_TOOL, "success", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
-    (PYTHON_TOOL, "code", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
+    (TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
+    (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
    (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
    (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
    (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
@@ -231,7 +231,7 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
 @pytest.mark.slow
 @pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
    # TODO: fix this crash
-   # ("meetkai-functionary-medium-v3.2", 256, [], None),
+   ("meetkai-functionary-medium-v3.2", 256, [], None),
    ("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),
    ("meetkai-functionary-medium-v3.2", 256, [PYTHON_TOOL], 'none'),
    ("meetkai-functionary-medium-v3.1", 256, [], None),
@@ -247,9 +247,7 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
 
 @pytest.mark.slow
 @pytest.mark.parametrize("hf_repo,hf_file,template_override", [
-   # TODO: fix these
-   # ("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
-   # ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
+   ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
    ("bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
    ("bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
    ("bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
@@ -259,6 +257,8 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
    ("bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai/functionary-medium-v3.2", None)),
    ("bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama/Llama-3.2-3B-Instruct", None)),
    ("bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama/Llama-3.2-3B-Instruct", None)),
+   # TODO: fix this (times out)
+   # ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
 ])
 def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
    global server
@@ -276,7 +276,6 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 256,
        "messages": [
-           # {"role": "system", "content": "Use tools as appropriate."},
            {"role": "user", "content": "What is the weather in Istanbul?"},
        ],
        "tools": [WEATHER_TOOL],
@@ -295,21 +294,21 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
 
 
 @pytest.mark.slow
-@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
-   # TODO: fix these
-   # ('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
-   # (None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
-   # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
-   (None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
-   (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-   ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+@pytest.mark.parametrize("expected_arguments_override,hf_repo,hf_file,template_override", [
    (None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
    (None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
+   (None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
+   ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
+   (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+   ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
    (None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
+   (None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
    (None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
    (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
+   # TODO: fix this (times out)
+   # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
 ])
-def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
+def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
    global server
    server.n_slots = 1
    server.jinja = True
@@ -319,7 +318,7 @@ def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_
    server.model_hf_file = hf_file
    if template_override:
        (template_hf_repo, template_variant) = template_override
-       server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
+       server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
    server.start(timeout_seconds=15*60)
    res = server.make_request("POST", "/chat/completions", data={
```
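The `replace('/', '')` → `replace('/', '-')` change above fixes the template-file lookup: the cached templates use a dash where the repo name has a slash (e.g. `models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja`, as referenced by the test-chat.cpp hunks below), so an override like `("meta-llama/Llama-3.2-3B-Instruct", None)` previously resolved to the nonexistent `meta-llamaLlama-3.2-3B-Instruct.jinja` and tripped the `os.path.exists` assertion.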
```diff
@@ -327,7 +326,6 @@ def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "say hello world with python"},
-           # {"role": "user", "content": "Print a hello world message with python"},
        ],
        "tools": [PYTHON_TOOL],
        # Note: without these greedy params, Functionary v3.2 writes `def hello_world():\n    print("Hello, World!")\nhello_world()` which is correct but a pain to test.
@@ -342,8 +340,8 @@ def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_
    tool_call = tool_calls[0]
    assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
    actual_arguments = tool_call["function"]["arguments"]
-   if expected_arguments is not None:
-       assert actual_arguments == expected_arguments
+   if expected_arguments_override is not None:
+       assert actual_arguments == expected_arguments_override
    else:
        actual_arguments = json.loads(actual_arguments)
        assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
```
```diff
@@ -1170,6 +1170,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
            grammar.awaiting_trigger = false;
            grammar.trigger_buffer.clear();
            llama_grammar_accept_str(grammar, piece);
+           LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
            return;
        } else {
            // TODO: consider a smarter incremental substring search algorithm (store last position to search from).
@@ -1181,9 +1182,11 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
                auto constrained_str = grammar.trigger_buffer.substr(pos);
                grammar.trigger_buffer.clear();
                llama_grammar_accept_str(grammar, constrained_str);
+               LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str());
                return;
            }
        }
+       LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n", token, piece.c_str(), grammar.trigger_buffer.c_str());
        return;
    }
 }
```
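The new `LLAMA_LOG_DEBUG` lines trace the lazy-trigger flow: while `awaiting_trigger` is set, decoded pieces accumulate in `trigger_buffer`; the moment a trigger word shows up, only the suffix starting at the trigger is replayed into the grammar and the buffer is dropped. A toy version of that buffering (simplified, not the real API):

```cpp
// Toy model of the trigger buffer: returns the constrained suffix once a
// trigger word appears in the accumulated output, or "" while still waiting.
#include <cassert>
#include <string>
#include <vector>

static std::string feed(std::string & trigger_buffer, const std::string & piece,
                        const std::vector<std::string> & trigger_words) {
    trigger_buffer += piece;
    for (const auto & word : trigger_words) {
        auto pos = trigger_buffer.find(word);
        if (pos != std::string::npos) {
            auto constrained_str = trigger_buffer.substr(pos);
            trigger_buffer.clear();
            return constrained_str;  // the grammar starts matching here
        }
    }
    return "";  // still awaiting trigger; keep buffering
}

int main() {
    std::string buffer;
    const std::vector<std::string> triggers = { "{\"name\":" };
    assert(feed(buffer, "Let me call a tool: ", triggers).empty());
    assert(feed(buffer, "{\"name\": \"get_we", triggers) == "{\"name\": \"get_we");
    return 0;
}
```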
```diff
@@ -372,8 +372,8 @@ static void test_template_output_parsers() {
        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
        std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
 
-       assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
-       assert_equals(std::string("llama 3.1 tool calls"), describe(common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
+       assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(tmpl, tools_params));
+       assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
 
        // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
        test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
@@ -383,6 +383,17 @@ static void test_template_output_parsers() {
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
    }
+   {
+       const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
+       std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
+
+       assert_equals(std::string("llama 3.x tool calls"), describe(tmpl, tools_params));
+
+       test_template(tmpl, end_tokens, text_message, tools,
+                     "Hello, world!", /* skip_grammar_test= */ true);
+       test_template(tmpl, end_tokens, tool_call_message, tools,
+                     "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
+   }
    {
        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
        std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
@@ -394,17 +405,6 @@ static void test_template_output_parsers() {
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "<function=special_function>{\"arg1\": 1}</function>");
    }
-   {
-       const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
-       std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
-
-       assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
-
-       test_template(tmpl, end_tokens, text_message, tools,
-                     "Hello, world!", /* skip_grammar_test= */ true);
-       test_template(tmpl, end_tokens, tool_call_message, tools,
-                     "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
-   }
    {
        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
        std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
```