tool-call: support Command R7B (+ return tool_plan "thoughts" in API) (#11585)

* `tool-call`: support Command R7B (w/ tool_plan return)

* `tool-call`: cleaner preservation of tokens + warn when likely bad chat template override

* `tool-call`: test cleanup / handle lazy grammar triggers
This commit is contained in:
Olivier Chafik 2025-02-02 09:25:38 +00:00 committed by GitHub
parent 69804487e0
commit bfcce4d693
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 420 additions and 56 deletions

View file

@ -22,9 +22,13 @@ static common_chat_msg msg_from_json(const json & message) {
"assistant",
"",
{},
/* .tool_plan = */ "",
};
if (message.contains("content") && !message.at("content").is_null()) {
ret.content = message.at("content").get<std::string>();
ret.content = message.at("content");
}
if (message.contains("tool_plan")) {
ret.tool_plan = message.at("tool_plan");
}
auto has_tool_calls = message.contains("tool_calls");
if (has_tool_calls) {
@ -171,8 +175,7 @@ const json llama_3_1_tools = { special_function_tool, code_interpreter_too
// Result bundle returned by init_delta().
// NOTE(review): this listing appears to be a rendered diff, so the older members
// (grammar, format) and the newer member (params) show up together — confirm
// against the actual file which of them really exist.
struct delta_data {
// Text delta that is compared against the expected delta and passed to common_chat_parse().
std::string delta;
// Grammar text; when non-empty it is passed to build_grammar().
std::string grammar;
// Chat format handed to common_chat_parse() together with the delta.
common_chat_format format;
// Chat params; the visible code reads its format, grammar, grammar_triggers and grammar_lazy members.
common_chat_params params;
};
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
@ -214,7 +217,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
break;
}
}
return { delta, params_full.grammar, params_full.format };
return { delta, params_full };
}
/*
@ -224,7 +227,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
*/
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
bool skip_grammar_test = false, bool skip_parser_test = false) {
bool expect_grammar_triggered = true) {
common_chat_msg expected_msg = msg_from_json(test_message);
auto user_message = json{
@ -238,45 +241,110 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
assert_equals(expected_delta, data.delta);
}
if (!skip_parser_test) {
const auto msg = common_chat_parse(data.delta, data.format);
if (expect_grammar_triggered) {
const auto msg = common_chat_parse(data.delta, data.params.format);
assert_msg_equals(expected_msg, msg);
}
if (!expected_msg.tool_calls.empty()) {
GGML_ASSERT(!data.grammar.empty());
GGML_ASSERT(!data.params.grammar.empty());
}
if (!data.grammar.empty()) {
auto grammar = build_grammar(data.grammar);
if (!data.params.grammar.empty()) {
auto grammar = build_grammar(data.params.grammar);
if (!grammar) {
throw std::runtime_error("Failed to build grammar");
}
// TODO: exercise lazy grammars + triggers here, instead of skipping the test
if (!skip_grammar_test) {
if (!match_string(data.delta, grammar.get())) {
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
"\n\nGrammar: " + data.grammar);
auto earliest_trigger_pos = std::string::npos;
auto constrained = data.delta;
for (const auto & trigger : data.params.grammar_triggers) {
auto pos = constrained.find(trigger.word);
if (pos == std::string::npos) {
continue;
}
if (pos > 0 && trigger.at_start) {
fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
continue;
}
if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
earliest_trigger_pos = pos;
}
}
auto grammar_triggered = false;
if (earliest_trigger_pos != std::string::npos) {
constrained = constrained.substr(earliest_trigger_pos);
grammar_triggered = true;
}
if (data.params.grammar_lazy) {
assert_equals(expect_grammar_triggered, grammar_triggered);
}
if (grammar_triggered && !match_string(constrained, grammar.get())) {
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
"\n\nGrammar: " + data.params.grammar);
}
}
}
}
static void test_template_output_parsers() {
auto text_message = json{
json text_message {
{ "role", "assistant" },
{ "content", "Hello, world!" },
};
auto tool_call_message = json{
json tool_calls = json::array({{
{ "type", "function" },
{ "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
}});
json tool_call_message {
{ "role", "assistant"},
{ "content", {}},
{ "tool_calls", {
{
{ "type", "function" },
{ "function", {
{ "name", "special_function" },
{ "arguments", "{\"arg1\": 1}" },
}},
},
}},
};
json tool_call_message_with_id {
{ "role", "assistant"},
{ "content", {}},
{ "tool_calls", {
{
{ "type", "function" },
{ "function", {
{ "name", "special_function" },
{ "arguments", "{\"arg1\": 1}" },
}},
{"id", "123456789"},
},
}},
{ "role", "assistant" },
{ "content", {} },
{ "tool_calls", json{ {
{ "type", "function" },
{ "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
} } }
{ "tool_calls", tool_calls }
};
json tool_call_plan_message_with_idx {
{ "role", "assistant"},
{ "content", {}},
{ "tool_plan", "I'm not so sure"},
{ "tool_calls", {
{
{ "type", "function" },
{ "function", {
{ "name", "special_function" },
{ "arguments", "{\"arg1\": 1}" },
}},
// Index of the tool call in the tool_calls array
{"id", "0"},
},
}},
{ "role", "assistant" },
{ "content", {} },
{ "tool_calls", tool_calls }
};
auto tool_call_message_with_id = json::parse(tool_call_message.dump());
tool_call_message_with_id["tool_calls"][0]["id"] = "123456789";
auto python_tool_call_message = json{
{ "role", "assistant" },
@ -322,6 +390,27 @@ static void test_template_output_parsers() {
inputs_tools_builtin.tools = json::array();
inputs_tools_builtin.tools.push_back(python_tool);
{
// Not supported yet
const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "<s>", "</s>");
assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
}
{
const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, tool_call_plan_message_with_idx, tools,
"<|START_THINKING|>I'm not so sure<|END_THINKING|>"
"<|START_ACTION|>[\n"
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
"]<|END_ACTION|>");
test_template(tmpl, end_tokens, text_message, tools,
"<|START_RESPONSE|>Hello, world!<|END_RESPONSE|>",
/* expect_grammar_triggered= */ false);
}
{
const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens{ "<end_of_turn>" };
@ -362,11 +451,10 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(
tmpl, end_tokens, tool_call_message_with_id, tools,
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
/* skip_grammar_test= */ true);
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
}
{
const common_chat_template tmpl(
@ -388,7 +476,7 @@ static void test_template_output_parsers() {
inputs_tools)
.format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<tool_call>\n"
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
@ -413,7 +501,7 @@ static void test_template_output_parsers() {
inputs_tools_builtin)
.format);
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
test_template(tmpl, end_tokens, python_tool_call_message, tools,
@ -428,7 +516,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
}
@ -440,7 +528,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<function=special_function>{\"arg1\": 1}</function>");
}
@ -455,7 +543,7 @@ static void test_template_output_parsers() {
test_template(tmpl, end_tokens, text_message, {},
"all\n"
"Hello, world!",
/* skip_grammar_test= */ true);
/* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"special_function\n"
"{\"arg1\": 1}");
@ -467,7 +555,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
}
@ -478,7 +566,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function\n"
"```json\n"