`tool-call`: support Command R7B (+ return tool_plan "thoughts" in API) (#11585)
* `tool-call`: support Command R7B (w/ tool_plan return)
* `tool-call`: cleaner preservation of tokens + warn when likely bad chat template override
* `tool-call`: test cleanup / handle lazy grammar triggers
This commit is contained in:
parent
69804487e0
commit
bfcce4d693
8 changed files with 420 additions and 56 deletions
|
@ -22,9 +22,13 @@ static common_chat_msg msg_from_json(const json & message) {
|
|||
"assistant",
|
||||
"",
|
||||
{},
|
||||
/* .tool_plan = */ "",
|
||||
};
|
||||
if (message.contains("content") && !message.at("content").is_null()) {
|
||||
ret.content = message.at("content").get<std::string>();
|
||||
ret.content = message.at("content");
|
||||
}
|
||||
if (message.contains("tool_plan")) {
|
||||
ret.tool_plan = message.at("tool_plan");
|
||||
}
|
||||
auto has_tool_calls = message.contains("tool_calls");
|
||||
if (has_tool_calls) {
|
||||
|
@ -171,8 +175,7 @@ const json llama_3_1_tools = { special_function_tool, code_interpreter_too
|
|||
|
||||
struct delta_data {
|
||||
std::string delta;
|
||||
std::string grammar;
|
||||
common_chat_format format;
|
||||
common_chat_params params;
|
||||
};
|
||||
|
||||
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
|
||||
|
@ -214,7 +217,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
|
|||
break;
|
||||
}
|
||||
}
|
||||
return { delta, params_full.grammar, params_full.format };
|
||||
return { delta, params_full };
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -224,7 +227,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
|
|||
*/
|
||||
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
|
||||
const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
|
||||
bool skip_grammar_test = false, bool skip_parser_test = false) {
|
||||
bool expect_grammar_triggered = true) {
|
||||
common_chat_msg expected_msg = msg_from_json(test_message);
|
||||
|
||||
auto user_message = json{
|
||||
|
@ -238,45 +241,110 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
|||
assert_equals(expected_delta, data.delta);
|
||||
}
|
||||
|
||||
if (!skip_parser_test) {
|
||||
const auto msg = common_chat_parse(data.delta, data.format);
|
||||
if (expect_grammar_triggered) {
|
||||
const auto msg = common_chat_parse(data.delta, data.params.format);
|
||||
assert_msg_equals(expected_msg, msg);
|
||||
}
|
||||
|
||||
if (!expected_msg.tool_calls.empty()) {
|
||||
GGML_ASSERT(!data.grammar.empty());
|
||||
GGML_ASSERT(!data.params.grammar.empty());
|
||||
}
|
||||
if (!data.grammar.empty()) {
|
||||
auto grammar = build_grammar(data.grammar);
|
||||
if (!data.params.grammar.empty()) {
|
||||
auto grammar = build_grammar(data.params.grammar);
|
||||
if (!grammar) {
|
||||
throw std::runtime_error("Failed to build grammar");
|
||||
}
|
||||
// TODO: exercice lazy grammars + triggers here, instead of skipping the test
|
||||
if (!skip_grammar_test) {
|
||||
if (!match_string(data.delta, grammar.get())) {
|
||||
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
|
||||
"\n\nGrammar: " + data.grammar);
|
||||
auto earliest_trigger_pos = std::string::npos;
|
||||
auto constrained = data.delta;
|
||||
for (const auto & trigger : data.params.grammar_triggers) {
|
||||
auto pos = constrained.find(trigger.word);
|
||||
if (pos == std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
if (pos > 0 && trigger.at_start) {
|
||||
fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
|
||||
continue;
|
||||
}
|
||||
if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
|
||||
earliest_trigger_pos = pos;
|
||||
}
|
||||
}
|
||||
auto grammar_triggered = false;
|
||||
if (earliest_trigger_pos != std::string::npos) {
|
||||
constrained = constrained.substr(earliest_trigger_pos);
|
||||
grammar_triggered = true;
|
||||
}
|
||||
if (data.params.grammar_lazy) {
|
||||
assert_equals(expect_grammar_triggered, grammar_triggered);
|
||||
}
|
||||
|
||||
if (grammar_triggered && !match_string(constrained, grammar.get())) {
|
||||
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
|
||||
"\n\nGrammar: " + data.params.grammar);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void test_template_output_parsers() {
|
||||
auto text_message = json{
|
||||
json text_message {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "Hello, world!" },
|
||||
};
|
||||
auto tool_call_message = json{
|
||||
json tool_calls = json::array({{
|
||||
{ "type", "function" },
|
||||
{ "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
|
||||
}});
|
||||
|
||||
json tool_call_message {
|
||||
{ "role", "assistant"},
|
||||
{ "content", {}},
|
||||
{ "tool_calls", {
|
||||
{
|
||||
{ "type", "function" },
|
||||
{ "function", {
|
||||
{ "name", "special_function" },
|
||||
{ "arguments", "{\"arg1\": 1}" },
|
||||
}},
|
||||
},
|
||||
}},
|
||||
};
|
||||
json tool_call_message_with_id {
|
||||
{ "role", "assistant"},
|
||||
{ "content", {}},
|
||||
{ "tool_calls", {
|
||||
{
|
||||
{ "type", "function" },
|
||||
{ "function", {
|
||||
{ "name", "special_function" },
|
||||
{ "arguments", "{\"arg1\": 1}" },
|
||||
}},
|
||||
{"id", "123456789"},
|
||||
},
|
||||
}},
|
||||
{ "role", "assistant" },
|
||||
{ "content", {} },
|
||||
{ "tool_calls", json{ {
|
||||
{ "type", "function" },
|
||||
{ "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
|
||||
} } }
|
||||
{ "tool_calls", tool_calls }
|
||||
};
|
||||
json tool_call_plan_message_with_idx {
|
||||
{ "role", "assistant"},
|
||||
{ "content", {}},
|
||||
{ "tool_plan", "I'm not so sure"},
|
||||
{ "tool_calls", {
|
||||
{
|
||||
{ "type", "function" },
|
||||
{ "function", {
|
||||
{ "name", "special_function" },
|
||||
{ "arguments", "{\"arg1\": 1}" },
|
||||
}},
|
||||
// Index of the tool call in the tool_calls array
|
||||
{"id", "0"},
|
||||
},
|
||||
}},
|
||||
{ "role", "assistant" },
|
||||
{ "content", {} },
|
||||
{ "tool_calls", tool_calls }
|
||||
};
|
||||
auto tool_call_message_with_id = json::parse(tool_call_message.dump());
|
||||
tool_call_message_with_id["tool_calls"][0]["id"] = "123456789";
|
||||
|
||||
auto python_tool_call_message = json{
|
||||
{ "role", "assistant" },
|
||||
|
@ -322,6 +390,27 @@ static void test_template_output_parsers() {
|
|||
inputs_tools_builtin.tools = json::array();
|
||||
inputs_tools_builtin.tools.push_back(python_tool);
|
||||
|
||||
{
|
||||
// Not supported yet
|
||||
const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "<s>", "</s>");
|
||||
assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, tool_call_plan_message_with_idx, tools,
|
||||
"<|START_THINKING|>I'm not so sure<|END_THINKING|>"
|
||||
"<|START_ACTION|>[\n"
|
||||
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
|
||||
"]<|END_ACTION|>");
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
"<|START_RESPONSE|>Hello, world!<|END_RESPONSE|>",
|
||||
/* expect_grammar_triggered= */ false);
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens{ "<end_of_turn>" };
|
||||
|
@ -362,11 +451,10 @@ static void test_template_output_parsers() {
|
|||
|
||||
assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
|
||||
test_template(
|
||||
tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
|
||||
/* skip_grammar_test= */ true);
|
||||
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(
|
||||
|
@ -388,7 +476,7 @@ static void test_template_output_parsers() {
|
|||
inputs_tools)
|
||||
.format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
|
@ -413,7 +501,7 @@ static void test_template_output_parsers() {
|
|||
inputs_tools_builtin)
|
||||
.format);
|
||||
|
||||
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
|
||||
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
||||
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
|
||||
test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
||||
|
@ -428,7 +516,7 @@ static void test_template_output_parsers() {
|
|||
|
||||
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
|
@ -440,7 +528,7 @@ static void test_template_output_parsers() {
|
|||
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<function=special_function>{\"arg1\": 1}</function>");
|
||||
}
|
||||
|
@ -455,7 +543,7 @@ static void test_template_output_parsers() {
|
|||
test_template(tmpl, end_tokens, text_message, {},
|
||||
"all\n"
|
||||
"Hello, world!",
|
||||
/* skip_grammar_test= */ true);
|
||||
/* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"special_function\n"
|
||||
"{\"arg1\": 1}");
|
||||
|
@ -467,7 +555,7 @@ static void test_template_output_parsers() {
|
|||
|
||||
assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
||||
}
|
||||
|
@ -478,7 +566,7 @@ static void test_template_output_parsers() {
|
|||
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
"```json\n"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue