fix test-chat-handler grammar tests

This commit is contained in:
ochafik 2025-01-27 20:13:09 +00:00
parent 118f799ae4
commit add9124115
3 changed files with 82 additions and 59 deletions

View file

@ -363,6 +363,7 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) { static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
fprintf(stderr, "[%s]\n", __func__); fprintf(stderr, "[%s]\n", __func__);
// TODO: get from request body.
auto builtin_tools = json {"wolfram_alpha", "brave_search"}; auto builtin_tools = json {"wolfram_alpha", "brave_search"};
common_chat_data data; common_chat_data data;
@ -377,10 +378,16 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
tool_rules.push_back( tool_rules.push_back(
builder.add_rule( builder.add_rule(
name + "-call", name + "-call",
"\"<|python_tag|>\" \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + "\"{\" "
// " ( \"\\\"type\\\": \\\"function\\\", \" | space ) "
"\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) + builder.add_schema(name + "-args", parameters) +
" \"}\"")); " \"}\""));
if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
}
}); });
tool_rules.push_back(builder.add_rule("builtin-tool-call", "\"<|python_tag|>\" .*"));
if (params.tool_choice != "required") { if (params.tool_choice != "required") {
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
} }
@ -391,11 +398,27 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
{"builtin_tools", builtin_tools}, {"builtin_tools", builtin_tools},
}); });
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg { data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
static std::regex function_regex("<\\|python_tag\\|>\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": "); static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
static std::regex close_regex("\\}"); static std::regex close_regex("\\}");
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\((.*)\)");
std::smatch match;
if (std::regex_match(input, match, builtin_call_regex)) {
auto arguments = json::parse("[" + match[2].str() + "]");
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
/* .name = */ match[1],
/* .arguments = */ arguments.dump(),
/* .id = */ "",
},
},
};
}
return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true); return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
}); });
fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
return data; return data;
} }
@ -435,7 +458,6 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true); auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
return res; return res;
}); });
fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
return data; return data;
} }
@ -590,9 +612,8 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
} else if (type != "string") { } else if (type != "string") {
throw std::runtime_error("Invalid type in python tool: " + type.dump()); throw std::runtime_error("Invalid type in python tool: " + type.dump());
} }
} else {
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
} }
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
}); });
if (has_raw_python) { if (has_raw_python) {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));

View file

@ -129,6 +129,7 @@ class chat_template {
bool supports_tools() const { return supports_tools_; } bool supports_tools() const { return supports_tools_; }
bool supports_tool_calls() const { return supports_tool_calls_; } bool supports_tool_calls() const { return supports_tool_calls_; }
bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; } bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
bool requires_object_arguments() const { return requires_object_arguments_; }
std::string apply( std::string apply(
const nlohmann::ordered_json & messages, const nlohmann::ordered_json & messages,
@ -201,12 +202,14 @@ class chat_template {
for (auto & tool_call : message.at("tool_calls")) { for (auto & tool_call : message.at("tool_calls")) {
if (tool_call["type"] == "function") { if (tool_call["type"] == "function") {
auto & function = tool_call.at("function"); auto & function = tool_call.at("function");
std::string arguments = function.at("arguments"); auto & arguments = function.at("arguments");
if (arguments.is_string()) {
try { try {
function["arguments"] = json::parse(arguments); arguments = json::parse(arguments.get<std::string>());
} catch (const std::exception & ecvt) { } catch (const std::exception & ecvt) {
fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what()); fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
function["arguments"] = arguments; arguments = arguments;
}
} }
} }
} }

View file

@ -72,32 +72,17 @@ static std::string dump(const json & j) {
return minja::Value(j).dump(-1, /* to_json= */ true); return minja::Value(j).dump(-1, /* to_json= */ true);
} }
static void assert_msg_equals(const common_chat_msg & result, const std::string & expected_content, const json & expected_tool_calls) { static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
assert_equals(expected_content, result.content); assert_equals(expected.role, actual.role);
auto tool_calls = json::array(); assert_equals(expected.content, actual.content);
for (const auto & tc : result.tool_calls) { assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
auto arguments = tc.arguments; for (size_t i = 0; i < expected.tool_calls.size(); i++) {
try { const auto & expected_tool_call = expected.tool_calls[i];
arguments = dump(json::parse(arguments)); const auto & actual_tool_call = actual.tool_calls[i];
} catch (const std::exception & e) { assert_equals(expected_tool_call.name, actual_tool_call.name);
// ignore assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments)));
assert_equals(expected_tool_call.id, actual_tool_call.id);
} }
auto tool_call = json {
{"type", "function"},
{"function", {
{"arguments", arguments},
{"name", tc.name},
}},
};
if (!tc.id.empty()) {
tool_call["id"] = tc.id;
}
tool_calls.push_back(tool_call);
}
// Reparse / dump w/ non-ordered JSON variant.
auto expected = nlohmann::json::parse(expected_tool_calls.dump()).dump();
auto actual = nlohmann::json::parse(tool_calls.dump()).dump();
assert_equals(expected, actual);
} }
const auto special_function_tool = json::parse(R"({ const auto special_function_tool = json::parse(R"({
@ -373,7 +358,19 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) { static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
// auto tool_call_style = common_tool_call_style_detect(tmpl); // auto tool_call_style = common_tool_call_style_detect(tmpl);
auto & tool_calls = tool_calling_message.at("tool_calls"); common_chat_msg expected_msg {
"assistant",
"",
{},
};
for (const auto & tc : tool_calling_message.at("tool_calls")) {
const auto & arguments = tc.at("function").at("arguments");
expected_msg.tool_calls.push_back({
tc.at("function").at("name").get<std::string>(),
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
tc.contains("id") ? tc.at("id").get<std::string>() : "",
});
}
// Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false, // Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
// get the diff and try and parse it w/ the grammar. // get the diff and try and parse it w/ the grammar.
@ -398,12 +395,12 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl; std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
const auto msg = chat_data.parser->parse_final(full_delta); const auto msg = chat_data.parser->parse_final(full_delta);
assert_msg_equals(msg, "", tool_calls); assert_msg_equals(expected_msg, msg);
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, { auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
{"role", "assistant"}, {"role", "assistant"},
{"content", {}}, {"content", {}},
{"tool_calls", tool_calls} {"tool_calls", tool_calling_message.at("tool_calls")}
}, tools); }, tools);
if (!match_string(content_less_delta, grammar.get())) { if (!match_string(content_less_delta, grammar.get())) {
throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar); throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar);
@ -433,7 +430,9 @@ static void test_grammars() {
{"type", "function"}, {"type", "function"},
{"function", { {"function", {
{"name", "python"}, {"name", "python"},
{"arguments", "print('hey')"} {"arguments", {
{"code", "print('hey')"},
}},
}}, }},
}}} }}}
}; };
@ -442,12 +441,12 @@ static void test_grammars() {
const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>"); const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
test_template(tmpl, { "</s>" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true); test_template(tmpl, { "</s>" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true);
} }
// { {
// const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"); const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
// // assert_equals(tmpl.requires_object_arguments_, true); // assert_equals(tmpl.requires_object_arguments_, true);
// test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools); test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
// test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools); test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools);
// } }
{ {
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>"); const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools); test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
@ -456,22 +455,22 @@ static void test_grammars() {
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"); const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools); test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
} }
// { {
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>"); const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
// } }
// { {
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>"); const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools); test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools);
// } }
{ {
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>"); const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
} }
// { {
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"); const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
// } }
{ {
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>"); const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);