align Command R7B w/ --think / reasoning_content behaviour

This commit is contained in:
Olivier Chafik 2025-02-05 15:47:37 +00:00
parent 3841a163ef
commit e6d9b52480
9 changed files with 176 additions and 87 deletions

View file

@ -24,7 +24,7 @@ static common_chat_msg msg_from_json(const json & message) {
ret.content = message.at("content");
}
if (message.contains("tool_plan")) {
ret.tool_plan = message.at("tool_plan");
ret.reasoning_content = message.at("tool_plan");
}
if (message.contains("reasoning_content")) {
ret.reasoning_content = message.at("reasoning_content");
@ -109,7 +109,6 @@ static void assert_msg_equals(const common_chat_msg & expected, const common_cha
assert_equals(expected.role, actual.role);
assert_equals(expected.content, actual.content);
assert_equals(expected.reasoning_content, actual.reasoning_content);
assert_equals(expected.tool_plan, actual.tool_plan);
assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
for (size_t i = 0; i < expected.tool_calls.size(); i++) {
const auto & expected_tool_call = expected.tool_calls[i];
@ -181,13 +180,15 @@ struct delta_data {
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
const json & user_message, const json & delta_message, const json & tools,
const json & tool_choice) {
const json & tool_choice,
bool think = false) {
common_chat_inputs inputs;
inputs.parallel_tool_calls = true;
inputs.messages = json::array();
inputs.messages.push_back(user_message);
inputs.tools = tools;
inputs.tool_choice = tool_choice;
inputs.think = think;
auto params_prefix = common_chat_params_init(tmpl, inputs);
inputs.messages.push_back(delta_message);
@ -229,7 +230,8 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
bool expect_grammar_triggered = true,
bool test_grammar_if_triggered = true) {
bool test_grammar_if_triggered = true,
bool think = false) {
common_chat_msg expected_msg = msg_from_json(test_message);
auto user_message = json{
@ -238,7 +240,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
};
for (const auto & tool_choice : json({ "auto", "required" })) {
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice);
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice, think);
if (!expected_delta.empty()) {
assert_equals(expected_delta, data.delta);
}
@ -297,10 +299,14 @@ static void test_template_output_parsers() {
{ "role", "assistant" },
{ "content", "Hello, world!\nWhat's up?" },
};
json message_assist_thoughts_unparsed {
json message_assist_thoughts_unparsed_think {
{ "role", "assistant" },
{ "content", "<think>I'm thinking</think>Hello, world!\nWhat's up?" },
};
json message_assist_thoughts_unparsed_r7b {
{ "role", "assistant" },
{ "content", "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?" },
};
json message_assist_thoughts {
{ "role", "assistant" },
{ "content", "Hello, world!\nWhat's up?" },
@ -371,7 +377,6 @@ static void test_template_output_parsers() {
json message_assist_call_idx {
{ "role", "assistant"},
{ "content", {}},
{ "tool_plan", "I'm not so sure"},
{ "tool_calls", {
{
{ "type", "function" },
@ -387,6 +392,8 @@ static void test_template_output_parsers() {
{ "content", {} },
{ "tool_calls", tool_calls }
};
json message_assist_call_tool_plan_idx = message_assist_call_idx;
message_assist_call_tool_plan_idx["tool_plan"] = "I'm thinking";
auto python_message_assist_call = json{
{ "role", "assistant" },
@ -448,14 +455,52 @@ static void test_template_output_parsers() {
const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_no_tools).format);
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_THINK, common_chat_params_init(tmpl, inputs_tools_think).format);
assert_msg_equals(msg_from_json(message_assist),
common_chat_parse(
"Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_COMMAND_R7B));
assert_msg_equals(msg_from_json(message_assist),
common_chat_parse(
"Hello, world!\nWhat's up?<|END_RESPONSE|>",
COMMON_CHAT_FORMAT_COMMAND_R7B));
assert_msg_equals(msg_from_json(message_assist),
common_chat_parse(
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
COMMON_CHAT_FORMAT_COMMAND_R7B));
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
common_chat_parse(
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
COMMON_CHAT_FORMAT_COMMAND_R7B));
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
common_chat_parse(
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
"Hello, world!\nWhat's up?<|END_RESPONSE|>",
COMMON_CHAT_FORMAT_COMMAND_R7B));
assert_msg_equals(msg_from_json(message_assist_thoughts),
common_chat_parse(
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
COMMON_CHAT_FORMAT_COMMAND_R7B_THINK));
test_template(tmpl, end_tokens, message_assist_call_idx, tools,
"<|START_THINKING|>I'm not so sure<|END_THINKING|>"
"<|START_THINKING|><|END_THINKING|>"
"<|START_ACTION|>[\n"
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
"]<|END_ACTION|>");
test_template(tmpl, end_tokens, message_assist_call_tool_plan_idx, tools,
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
"<|START_ACTION|>[\n"
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
"]<|END_ACTION|>",
/* expect_grammar_triggered= */ true,
/* test_grammar_if_triggered= */ true,
/* think= */ true);
test_template(tmpl, end_tokens, message_assist, tools,
"<|START_RESPONSE|>Hello, world!\n"
"What's up?<|END_RESPONSE|>",
@ -616,12 +661,17 @@ static void test_template_output_parsers() {
"<s>", "</s>");
std::vector<std::string> end_tokens{ "<end▁of▁sentence>" };
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK, common_chat_params_init(tmpl, inputs_tools_think).format);
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1));
assert_msg_equals(msg_from_json(message_assist_thoughts), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK));
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
assert_msg_equals(msg_from_json(message_assist_thoughts),
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK));
// test_template(tmpl, end_tokens, message_assist_call, tools,
// "<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function\n"
// "```json\n"
@ -637,12 +687,17 @@ static void test_template_output_parsers() {
"<s>", "</s>");
std::vector<std::string> end_tokens{ "<end▁of▁sentence>" };
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK, common_chat_params_init(tmpl, inputs_tools_think).format);
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1));
assert_msg_equals(msg_from_json(message_assist_thoughts), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK));
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
assert_msg_equals(msg_from_json(message_assist_thoughts),
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK));
assert_msg_equals(msg_from_json(message_assist_call_thoughts_unparsed),
common_chat_parse(