fix test-chat-handler grammar tests

This commit is contained in:
ochafik 2025-01-27 20:13:09 +00:00
parent 118f799ae4
commit add9124115
3 changed files with 82 additions and 59 deletions

View file

@ -363,6 +363,7 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__);
// TODO: get from request body.
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
common_chat_data data;
@ -377,10 +378,16 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"<|python_tag|>\" \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
"\"{\" "
// " ( \"\\\"type\\\": \\\"function\\\", \" | space ) "
"\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) +
" \"}\""));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
}
});
tool_rules.push_back(builder.add_rule("builtin-tool-call", "\"<|python_tag|>\" .*"));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
}
@ -391,11 +398,27 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
{"builtin_tools", builtin_tools},
});
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
static std::regex function_regex("<\\|python_tag\\|>\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
static std::regex close_regex("\\}");
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\((.*)\)");
std::smatch match;
if (std::regex_match(input, match, builtin_call_regex)) {
auto arguments = json::parse("[" + match[2].str() + "]");
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
/* .name = */ match[1],
/* .arguments = */ arguments.dump(),
/* .id = */ "",
},
},
};
}
return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
});
fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
return data;
}
@ -435,7 +458,6 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
return res;
});
fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
return data;
}
@ -590,9 +612,8 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
} else if (type != "string") {
throw std::runtime_error("Invalid type in python tool: " + type.dump());
}
} else {
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
}
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
});
if (has_raw_python) {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));

View file

@ -129,6 +129,7 @@ class chat_template {
bool supports_tools() const { return supports_tools_; }
bool supports_tool_calls() const { return supports_tool_calls_; }
bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
bool requires_object_arguments() const { return requires_object_arguments_; }
std::string apply(
const nlohmann::ordered_json & messages,
@ -201,12 +202,14 @@ class chat_template {
for (auto & tool_call : message.at("tool_calls")) {
if (tool_call["type"] == "function") {
auto & function = tool_call.at("function");
std::string arguments = function.at("arguments");
try {
function["arguments"] = json::parse(arguments);
} catch (const std::exception & ecvt) {
fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
function["arguments"] = arguments;
auto & arguments = function.at("arguments");
if (arguments.is_string()) {
try {
arguments = json::parse(arguments.get<std::string>());
} catch (const std::exception & ecvt) {
fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
arguments = arguments;
}
}
}
}

View file

@ -72,32 +72,17 @@ static std::string dump(const json & j) {
return minja::Value(j).dump(-1, /* to_json= */ true);
}
static void assert_msg_equals(const common_chat_msg & result, const std::string & expected_content, const json & expected_tool_calls) {
assert_equals(expected_content, result.content);
auto tool_calls = json::array();
for (const auto & tc : result.tool_calls) {
auto arguments = tc.arguments;
try {
arguments = dump(json::parse(arguments));
} catch (const std::exception & e) {
// ignore
}
auto tool_call = json {
{"type", "function"},
{"function", {
{"arguments", arguments},
{"name", tc.name},
}},
};
if (!tc.id.empty()) {
tool_call["id"] = tc.id;
}
tool_calls.push_back(tool_call);
static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
assert_equals(expected.role, actual.role);
assert_equals(expected.content, actual.content);
assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
for (size_t i = 0; i < expected.tool_calls.size(); i++) {
const auto & expected_tool_call = expected.tool_calls[i];
const auto & actual_tool_call = actual.tool_calls[i];
assert_equals(expected_tool_call.name, actual_tool_call.name);
assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments)));
assert_equals(expected_tool_call.id, actual_tool_call.id);
}
// Reparse / dump w/ non-ordered JSON variant.
auto expected = nlohmann::json::parse(expected_tool_calls.dump()).dump();
auto actual = nlohmann::json::parse(tool_calls.dump()).dump();
assert_equals(expected, actual);
}
const auto special_function_tool = json::parse(R"({
@ -373,7 +358,19 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
// auto tool_call_style = common_tool_call_style_detect(tmpl);
auto & tool_calls = tool_calling_message.at("tool_calls");
common_chat_msg expected_msg {
"assistant",
"",
{},
};
for (const auto & tc : tool_calling_message.at("tool_calls")) {
const auto & arguments = tc.at("function").at("arguments");
expected_msg.tool_calls.push_back({
tc.at("function").at("name").get<std::string>(),
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
tc.contains("id") ? tc.at("id").get<std::string>() : "",
});
}
// Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
// get the diff and try and parse it w/ the grammar.
@ -398,12 +395,12 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
const auto msg = chat_data.parser->parse_final(full_delta);
assert_msg_equals(msg, "", tool_calls);
assert_msg_equals(expected_msg, msg);
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
{"role", "assistant"},
{"content", {}},
{"tool_calls", tool_calls}
{"tool_calls", tool_calling_message.at("tool_calls")}
}, tools);
if (!match_string(content_less_delta, grammar.get())) {
throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar);
@ -433,7 +430,9 @@ static void test_grammars() {
{"type", "function"},
{"function", {
{"name", "python"},
{"arguments", "print('hey')"}
{"arguments", {
{"code", "print('hey')"},
}},
}},
}}}
};
@ -442,12 +441,12 @@ static void test_grammars() {
const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
test_template(tmpl, { "</s>" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true);
}
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
// // assert_equals(tmpl.requires_object_arguments_, true);
// test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
// test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools);
// }
{
const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
// assert_equals(tmpl.requires_object_arguments_, true);
test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools);
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
@ -456,22 +455,22 @@ static void test_grammars() {
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
}
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
// }
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools);
// }
{
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools);
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
}
// {
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
// }
{
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);