Align Command R7B with --think / reasoning_content behaviour
This commit is contained in:
parent
3841a163ef
commit
e6d9b52480
9 changed files with 176 additions and 87 deletions
|
@ -24,7 +24,7 @@ static common_chat_msg msg_from_json(const json & message) {
|
|||
ret.content = message.at("content");
|
||||
}
|
||||
if (message.contains("tool_plan")) {
|
||||
ret.tool_plan = message.at("tool_plan");
|
||||
ret.reasoning_content = message.at("tool_plan");
|
||||
}
|
||||
if (message.contains("reasoning_content")) {
|
||||
ret.reasoning_content = message.at("reasoning_content");
|
||||
|
@ -109,7 +109,6 @@ static void assert_msg_equals(const common_chat_msg & expected, const common_cha
|
|||
assert_equals(expected.role, actual.role);
|
||||
assert_equals(expected.content, actual.content);
|
||||
assert_equals(expected.reasoning_content, actual.reasoning_content);
|
||||
assert_equals(expected.tool_plan, actual.tool_plan);
|
||||
assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
|
||||
for (size_t i = 0; i < expected.tool_calls.size(); i++) {
|
||||
const auto & expected_tool_call = expected.tool_calls[i];
|
||||
|
@ -181,13 +180,15 @@ struct delta_data {
|
|||
|
||||
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
|
||||
const json & user_message, const json & delta_message, const json & tools,
|
||||
const json & tool_choice) {
|
||||
const json & tool_choice,
|
||||
bool think = false) {
|
||||
common_chat_inputs inputs;
|
||||
inputs.parallel_tool_calls = true;
|
||||
inputs.messages = json::array();
|
||||
inputs.messages.push_back(user_message);
|
||||
inputs.tools = tools;
|
||||
inputs.tool_choice = tool_choice;
|
||||
inputs.think = think;
|
||||
auto params_prefix = common_chat_params_init(tmpl, inputs);
|
||||
|
||||
inputs.messages.push_back(delta_message);
|
||||
|
@ -229,7 +230,8 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
|
|||
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
|
||||
const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
|
||||
bool expect_grammar_triggered = true,
|
||||
bool test_grammar_if_triggered = true) {
|
||||
bool test_grammar_if_triggered = true,
|
||||
bool think = false) {
|
||||
common_chat_msg expected_msg = msg_from_json(test_message);
|
||||
|
||||
auto user_message = json{
|
||||
|
@ -238,7 +240,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
|||
};
|
||||
|
||||
for (const auto & tool_choice : json({ "auto", "required" })) {
|
||||
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice);
|
||||
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice, think);
|
||||
if (!expected_delta.empty()) {
|
||||
assert_equals(expected_delta, data.delta);
|
||||
}
|
||||
|
@ -297,10 +299,14 @@ static void test_template_output_parsers() {
|
|||
{ "role", "assistant" },
|
||||
{ "content", "Hello, world!\nWhat's up?" },
|
||||
};
|
||||
json message_assist_thoughts_unparsed {
|
||||
json message_assist_thoughts_unparsed_think {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "<think>I'm thinking</think>Hello, world!\nWhat's up?" },
|
||||
};
|
||||
json message_assist_thoughts_unparsed_r7b {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?" },
|
||||
};
|
||||
json message_assist_thoughts {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "Hello, world!\nWhat's up?" },
|
||||
|
@ -371,7 +377,6 @@ static void test_template_output_parsers() {
|
|||
json message_assist_call_idx {
|
||||
{ "role", "assistant"},
|
||||
{ "content", {}},
|
||||
{ "tool_plan", "I'm not so sure"},
|
||||
{ "tool_calls", {
|
||||
{
|
||||
{ "type", "function" },
|
||||
|
@ -387,6 +392,8 @@ static void test_template_output_parsers() {
|
|||
{ "content", {} },
|
||||
{ "tool_calls", tool_calls }
|
||||
};
|
||||
json message_assist_call_tool_plan_idx = message_assist_call_idx;
|
||||
message_assist_call_tool_plan_idx["tool_plan"] = "I'm thinking";
|
||||
|
||||
auto python_message_assist_call = json{
|
||||
{ "role", "assistant" },
|
||||
|
@ -448,14 +455,52 @@ static void test_template_output_parsers() {
|
|||
const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_THINK, common_chat_params_init(tmpl, inputs_tools_think).format);
|
||||
|
||||
assert_msg_equals(msg_from_json(message_assist),
|
||||
common_chat_parse(
|
||||
"Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
||||
assert_msg_equals(msg_from_json(message_assist),
|
||||
common_chat_parse(
|
||||
"Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
||||
assert_msg_equals(msg_from_json(message_assist),
|
||||
common_chat_parse(
|
||||
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
|
||||
common_chat_parse(
|
||||
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
|
||||
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
|
||||
common_chat_parse(
|
||||
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
|
||||
"Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
||||
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts),
|
||||
common_chat_parse(
|
||||
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
|
||||
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B_THINK));
|
||||
|
||||
test_template(tmpl, end_tokens, message_assist_call_idx, tools,
|
||||
"<|START_THINKING|>I'm not so sure<|END_THINKING|>"
|
||||
"<|START_THINKING|><|END_THINKING|>"
|
||||
"<|START_ACTION|>[\n"
|
||||
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
|
||||
"]<|END_ACTION|>");
|
||||
test_template(tmpl, end_tokens, message_assist_call_tool_plan_idx, tools,
|
||||
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
|
||||
"<|START_ACTION|>[\n"
|
||||
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
|
||||
"]<|END_ACTION|>",
|
||||
/* expect_grammar_triggered= */ true,
|
||||
/* test_grammar_if_triggered= */ true,
|
||||
/* think= */ true);
|
||||
test_template(tmpl, end_tokens, message_assist, tools,
|
||||
"<|START_RESPONSE|>Hello, world!\n"
|
||||
"What's up?<|END_RESPONSE|>",
|
||||
|
@ -616,12 +661,17 @@ static void test_template_output_parsers() {
|
|||
"<s>", "</s>");
|
||||
std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK, common_chat_params_init(tmpl, inputs_tools_think).format);
|
||||
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
|
||||
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts),
|
||||
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK));
|
||||
// test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
// "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
// "```json\n"
|
||||
|
@ -637,12 +687,17 @@ static void test_template_output_parsers() {
|
|||
"<s>", "</s>");
|
||||
std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
|
||||
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK, common_chat_params_init(tmpl, inputs_tools_think).format);
|
||||
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
|
||||
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts),
|
||||
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK));
|
||||
|
||||
assert_msg_equals(msg_from_json(message_assist_call_thoughts_unparsed),
|
||||
common_chat_parse(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue