fix test-chat-handler grammar tests
This commit is contained in:
parent
118f799ae4
commit
add9124115
3 changed files with 82 additions and 59 deletions
|
@ -363,6 +363,7 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
|
|||
|
||||
static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
fprintf(stderr, "[%s]\n", __func__);
|
||||
// TODO: get from request body.
|
||||
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
||||
common_chat_data data;
|
||||
|
||||
|
@ -377,10 +378,16 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
|
|||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"<|python_tag|>\" \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
|
||||
"\"{\" "
|
||||
// " ( \"\\\"type\\\": \\\"function\\\", \" | space ) "
|
||||
"\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
|
||||
builder.add_schema(name + "-args", parameters) +
|
||||
" \"}\""));
|
||||
if (params.tool_choice != "required") {
|
||||
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
|
||||
}
|
||||
});
|
||||
tool_rules.push_back(builder.add_rule("builtin-tool-call", "\"<|python_tag|>\" .*"));
|
||||
if (params.tool_choice != "required") {
|
||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||
}
|
||||
|
@ -391,11 +398,27 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
|
|||
{"builtin_tools", builtin_tools},
|
||||
});
|
||||
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
||||
static std::regex function_regex("<\\|python_tag\\|>\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
|
||||
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
|
||||
static std::regex close_regex("\\}");
|
||||
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\((.*)\)");
|
||||
|
||||
std::smatch match;
|
||||
if (std::regex_match(input, match, builtin_call_regex)) {
|
||||
auto arguments = json::parse("[" + match[2].str() + "]");
|
||||
return {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ match.prefix().str(),
|
||||
/* .tool_calls = */ {
|
||||
{
|
||||
/* .name = */ match[1],
|
||||
/* .arguments = */ arguments.dump(),
|
||||
/* .id = */ "",
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
|
||||
});
|
||||
fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
|
||||
return data;
|
||||
}
|
||||
|
||||
|
@ -435,7 +458,6 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
|
|||
auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
|
||||
return res;
|
||||
});
|
||||
fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
|
||||
return data;
|
||||
}
|
||||
|
||||
|
@ -590,9 +612,8 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
|
|||
} else if (type != "string") {
|
||||
throw std::runtime_error("Invalid type in python tool: " + type.dump());
|
||||
}
|
||||
} else {
|
||||
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
||||
}
|
||||
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
||||
});
|
||||
if (has_raw_python) {
|
||||
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
||||
|
|
|
@ -129,6 +129,7 @@ class chat_template {
|
|||
bool supports_tools() const { return supports_tools_; }
|
||||
bool supports_tool_calls() const { return supports_tool_calls_; }
|
||||
bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
|
||||
bool requires_object_arguments() const { return requires_object_arguments_; }
|
||||
|
||||
std::string apply(
|
||||
const nlohmann::ordered_json & messages,
|
||||
|
@ -201,12 +202,14 @@ class chat_template {
|
|||
for (auto & tool_call : message.at("tool_calls")) {
|
||||
if (tool_call["type"] == "function") {
|
||||
auto & function = tool_call.at("function");
|
||||
std::string arguments = function.at("arguments");
|
||||
try {
|
||||
function["arguments"] = json::parse(arguments);
|
||||
} catch (const std::exception & ecvt) {
|
||||
fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
|
||||
function["arguments"] = arguments;
|
||||
auto & arguments = function.at("arguments");
|
||||
if (arguments.is_string()) {
|
||||
try {
|
||||
arguments = json::parse(arguments.get<std::string>());
|
||||
} catch (const std::exception & ecvt) {
|
||||
fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
|
||||
arguments = arguments;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -72,32 +72,17 @@ static std::string dump(const json & j) {
|
|||
return minja::Value(j).dump(-1, /* to_json= */ true);
|
||||
}
|
||||
|
||||
static void assert_msg_equals(const common_chat_msg & result, const std::string & expected_content, const json & expected_tool_calls) {
|
||||
assert_equals(expected_content, result.content);
|
||||
auto tool_calls = json::array();
|
||||
for (const auto & tc : result.tool_calls) {
|
||||
auto arguments = tc.arguments;
|
||||
try {
|
||||
arguments = dump(json::parse(arguments));
|
||||
} catch (const std::exception & e) {
|
||||
// ignore
|
||||
}
|
||||
auto tool_call = json {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"arguments", arguments},
|
||||
{"name", tc.name},
|
||||
}},
|
||||
};
|
||||
if (!tc.id.empty()) {
|
||||
tool_call["id"] = tc.id;
|
||||
}
|
||||
tool_calls.push_back(tool_call);
|
||||
static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
|
||||
assert_equals(expected.role, actual.role);
|
||||
assert_equals(expected.content, actual.content);
|
||||
assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
|
||||
for (size_t i = 0; i < expected.tool_calls.size(); i++) {
|
||||
const auto & expected_tool_call = expected.tool_calls[i];
|
||||
const auto & actual_tool_call = actual.tool_calls[i];
|
||||
assert_equals(expected_tool_call.name, actual_tool_call.name);
|
||||
assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments)));
|
||||
assert_equals(expected_tool_call.id, actual_tool_call.id);
|
||||
}
|
||||
// Reparse / dump w/ non-ordered JSON variant.
|
||||
auto expected = nlohmann::json::parse(expected_tool_calls.dump()).dump();
|
||||
auto actual = nlohmann::json::parse(tool_calls.dump()).dump();
|
||||
assert_equals(expected, actual);
|
||||
}
|
||||
|
||||
const auto special_function_tool = json::parse(R"({
|
||||
|
@ -373,7 +358,19 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
|
|||
|
||||
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
|
||||
// auto tool_call_style = common_tool_call_style_detect(tmpl);
|
||||
auto & tool_calls = tool_calling_message.at("tool_calls");
|
||||
common_chat_msg expected_msg {
|
||||
"assistant",
|
||||
"",
|
||||
{},
|
||||
};
|
||||
for (const auto & tc : tool_calling_message.at("tool_calls")) {
|
||||
const auto & arguments = tc.at("function").at("arguments");
|
||||
expected_msg.tool_calls.push_back({
|
||||
tc.at("function").at("name").get<std::string>(),
|
||||
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
tc.contains("id") ? tc.at("id").get<std::string>() : "",
|
||||
});
|
||||
}
|
||||
|
||||
// Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
|
||||
// get the diff and try and parse it w/ the grammar.
|
||||
|
@ -398,12 +395,12 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
|||
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
|
||||
|
||||
const auto msg = chat_data.parser->parse_final(full_delta);
|
||||
assert_msg_equals(msg, "", tool_calls);
|
||||
assert_msg_equals(expected_msg, msg);
|
||||
|
||||
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
|
||||
{"role", "assistant"},
|
||||
{"content", {}},
|
||||
{"tool_calls", tool_calls}
|
||||
{"tool_calls", tool_calling_message.at("tool_calls")}
|
||||
}, tools);
|
||||
if (!match_string(content_less_delta, grammar.get())) {
|
||||
throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar);
|
||||
|
@ -433,7 +430,9 @@ static void test_grammars() {
|
|||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "python"},
|
||||
{"arguments", "print('hey')"}
|
||||
{"arguments", {
|
||||
{"code", "print('hey')"},
|
||||
}},
|
||||
}},
|
||||
}}}
|
||||
};
|
||||
|
@ -442,12 +441,12 @@ static void test_grammars() {
|
|||
const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
||||
test_template(tmpl, { "</s>" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true);
|
||||
}
|
||||
// {
|
||||
// const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
|
||||
// // assert_equals(tmpl.requires_object_arguments_, true);
|
||||
// test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
|
||||
// test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools);
|
||||
// }
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
|
||||
// assert_equals(tmpl.requires_object_arguments_, true);
|
||||
test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
|
||||
test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools);
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||
test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
|
||||
|
@ -456,22 +455,22 @@ static void test_grammars() {
|
|||
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
|
||||
test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
|
||||
}
|
||||
// {
|
||||
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||
// }
|
||||
// {
|
||||
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools);
|
||||
// }
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools);
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||
}
|
||||
// {
|
||||
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
|
||||
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||
// }
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
|
||||
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue