--think to force any model to return reasoning_content (or just parse <think> for deepseek r1)
This commit is contained in:
parent
5d60cebbcc
commit
9d7c3cc51b
9 changed files with 306 additions and 145 deletions
|
@ -289,11 +289,19 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
|||
}
|
||||
|
||||
static void test_template_output_parsers() {
|
||||
json text_message {
|
||||
json message_user {
|
||||
{ "role", "user" },
|
||||
{ "content", "Hey there!" },
|
||||
};
|
||||
json message_assist {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "Hello, world!\nWhat's up?" },
|
||||
};
|
||||
json text_reasoning_message {
|
||||
json message_assist_thoughts_unparsed {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "<think>I'm thinking</think>Hello, world!\nWhat's up?" },
|
||||
};
|
||||
json message_assist_thoughts {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "Hello, world!\nWhat's up?" },
|
||||
{ "reasoning_content", "I'm thinking" },
|
||||
|
@ -303,7 +311,7 @@ static void test_template_output_parsers() {
|
|||
{ "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
|
||||
}});
|
||||
|
||||
json tool_call_message {
|
||||
json message_assist_call {
|
||||
{ "role", "assistant"},
|
||||
{ "content", {}},
|
||||
{ "tool_calls", {
|
||||
|
@ -316,7 +324,7 @@ static void test_template_output_parsers() {
|
|||
},
|
||||
}},
|
||||
};
|
||||
json tool_call_reasoning_message = {
|
||||
json message_assist_call_thoughts = {
|
||||
{ "role", "assistant" },
|
||||
{ "content", nullptr },
|
||||
{ "reasoning_content", "I'm\nthinking" },
|
||||
|
@ -330,7 +338,20 @@ static void test_template_output_parsers() {
|
|||
},
|
||||
}},
|
||||
};
|
||||
json tool_call_message_with_id {
|
||||
json message_assist_call_thoughts_unparsed = {
|
||||
{ "role", "assistant" },
|
||||
{ "content", "<think>I'm\nthinking</think>" },
|
||||
{ "tool_calls", {
|
||||
{
|
||||
{ "type", "function" },
|
||||
{ "function", {
|
||||
{ "name", "special_function" },
|
||||
{ "arguments", "{\"arg1\": 1}" },
|
||||
}},
|
||||
},
|
||||
}},
|
||||
};
|
||||
json message_assist_call_id {
|
||||
{ "role", "assistant"},
|
||||
{ "content", {}},
|
||||
{ "tool_calls", {
|
||||
|
@ -347,7 +368,7 @@ static void test_template_output_parsers() {
|
|||
{ "content", {} },
|
||||
{ "tool_calls", tool_calls }
|
||||
};
|
||||
json tool_call_plan_message_with_idx {
|
||||
json message_assist_call_idx {
|
||||
{ "role", "assistant"},
|
||||
{ "content", {}},
|
||||
{ "tool_plan", "I'm not so sure"},
|
||||
|
@ -367,7 +388,7 @@ static void test_template_output_parsers() {
|
|||
{ "tool_calls", tool_calls }
|
||||
};
|
||||
|
||||
auto python_tool_call_message = json{
|
||||
auto python_message_assist_call = json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", {} },
|
||||
{ "tool_calls", json{ {
|
||||
|
@ -382,7 +403,7 @@ static void test_template_output_parsers() {
|
|||
} },
|
||||
} } }
|
||||
};
|
||||
auto code_interpreter_tool_call_message = json{
|
||||
auto code_interpreter_message_assist_call = json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", {} },
|
||||
{ "tool_calls", json{ {
|
||||
|
@ -399,17 +420,24 @@ static void test_template_output_parsers() {
|
|||
};
|
||||
|
||||
common_chat_inputs inputs_no_tools;
|
||||
inputs_no_tools.messages = {
|
||||
{ { "role", "user" }, { "content", "Hey\nThere" } }
|
||||
};
|
||||
inputs_no_tools.messages = json::array({message_user});
|
||||
|
||||
common_chat_inputs inputs_tools = inputs_no_tools;
|
||||
inputs_tools.tools = json::array();
|
||||
inputs_tools.tools.push_back(special_function_tool);
|
||||
common_chat_inputs inputs_no_tools_think;
|
||||
inputs_no_tools_think.messages = json::array({message_user});
|
||||
inputs_no_tools_think.think = true;
|
||||
|
||||
common_chat_inputs inputs_tools_builtin = inputs_no_tools;
|
||||
inputs_tools_builtin.tools = json::array();
|
||||
inputs_tools_builtin.tools.push_back(python_tool);
|
||||
common_chat_inputs inputs_tools;
|
||||
inputs_tools.messages = json::array({message_user});
|
||||
inputs_tools.tools = json::array({special_function_tool});
|
||||
|
||||
common_chat_inputs inputs_tools_think;
|
||||
inputs_tools_think.messages = json::array({message_user});
|
||||
inputs_tools_think.tools = json::array({special_function_tool});
|
||||
inputs_tools_think.think = true;
|
||||
|
||||
common_chat_inputs inputs_tools_builtin;
|
||||
inputs_tools_builtin.messages = json::array({message_user});
|
||||
inputs_tools_builtin.tools = json::array({python_tool});
|
||||
|
||||
{
|
||||
// Not supported yet
|
||||
|
@ -423,12 +451,12 @@ static void test_template_output_parsers() {
|
|||
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, tool_call_plan_message_with_idx, tools,
|
||||
test_template(tmpl, end_tokens, message_assist_call_idx, tools,
|
||||
"<|START_THINKING|>I'm not so sure<|END_THINKING|>"
|
||||
"<|START_ACTION|>[\n"
|
||||
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
|
||||
"]<|END_ACTION|>");
|
||||
test_template(tmpl, end_tokens, text_message, tools,
|
||||
test_template(tmpl, end_tokens, message_assist, tools,
|
||||
"<|START_RESPONSE|>Hello, world!\n"
|
||||
"What's up?<|END_RESPONSE|>",
|
||||
/* expect_grammar_triggered= */ false);
|
||||
|
@ -448,12 +476,12 @@ static void test_template_output_parsers() {
|
|||
|
||||
// Generic tool calls don't generate / parse content-only messages symmetrically.
|
||||
|
||||
assert_msg_equals(msg_from_json(text_message),
|
||||
assert_msg_equals(msg_from_json(message_assist),
|
||||
common_chat_parse("{\n"
|
||||
" \"response\": \"Hello, world!\\nWhat's up?\"\n"
|
||||
"}",
|
||||
common_chat_params_init(tmpl, inputs_tools).format));
|
||||
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
test_template(tmpl, end_tokens, message_assist_call_id, tools,
|
||||
"{\n"
|
||||
" \"tool_calls\": [\n"
|
||||
" {\n"
|
||||
|
@ -473,9 +501,9 @@ static void test_template_output_parsers() {
|
|||
|
||||
assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(
|
||||
tmpl, end_tokens, tool_call_message_with_id, tools,
|
||||
tmpl, end_tokens, message_assist_call_id, tools,
|
||||
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
|
||||
}
|
||||
{
|
||||
|
@ -498,12 +526,12 @@ static void test_template_output_parsers() {
|
|||
inputs_tools)
|
||||
.format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tool_call>");
|
||||
test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, python_message_assist_call, tools,
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
|
||||
"</tool_call>");
|
||||
|
@ -523,12 +551,12 @@ static void test_template_output_parsers() {
|
|||
inputs_tools_builtin)
|
||||
.format);
|
||||
|
||||
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
||||
// test_template(tmpl, end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, code_interpreter_message_assist_call, llama_3_1_tools,
|
||||
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
|
||||
test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, python_message_assist_call, tools,
|
||||
"<|python_tag|>python.call(code=\"print('hey')\")");
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
|
@ -538,8 +566,8 @@ static void test_template_output_parsers() {
|
|||
|
||||
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
|
@ -550,8 +578,8 @@ static void test_template_output_parsers() {
|
|||
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
"<function=special_function>{\"arg1\": 1}</function>");
|
||||
}
|
||||
{
|
||||
|
@ -562,12 +590,12 @@ static void test_template_output_parsers() {
|
|||
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, {},
|
||||
test_template(tmpl, end_tokens, message_assist, {},
|
||||
"all\n"
|
||||
"Hello, world!\n"
|
||||
"What's up?",
|
||||
/* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
"special_function\n"
|
||||
"{\"arg1\": 1}");
|
||||
}
|
||||
|
@ -578,8 +606,8 @@ static void test_template_output_parsers() {
|
|||
|
||||
assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
||||
}
|
||||
{
|
||||
|
@ -590,10 +618,11 @@ static void test_template_output_parsers() {
|
|||
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, text_reasoning_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
assert_msg_equals(msg_from_json(text_reasoning_message), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
||||
// test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK));
|
||||
// test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
// "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
// "```json\n"
|
||||
// "{\"arg1\": 1}\n"
|
||||
|
@ -610,11 +639,12 @@ static void test_template_output_parsers() {
|
|||
|
||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, text_reasoning_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
assert_msg_equals(msg_from_json(text_reasoning_message), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
||||
test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
||||
assert_msg_equals(msg_from_json(message_assist_thoughts), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK));
|
||||
|
||||
assert_msg_equals(msg_from_json(tool_call_reasoning_message),
|
||||
assert_msg_equals(msg_from_json(message_assist_call_thoughts_unparsed),
|
||||
common_chat_parse(
|
||||
"<think>I'm\nthinking</think>\n\n"
|
||||
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
|
@ -622,7 +652,15 @@ static void test_template_output_parsers() {
|
|||
"{\"arg1\": 1}\n"
|
||||
"```<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
assert_msg_equals(msg_from_json(message_assist_call_thoughts),
|
||||
common_chat_parse(
|
||||
"<think>I'm\nthinking</think>\n\n"
|
||||
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
"```json\n"
|
||||
"{\"arg1\": 1}\n"
|
||||
"```<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK));
|
||||
test_template(tmpl, end_tokens, message_assist_call, tools,
|
||||
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||
"```json\n"
|
||||
"{\"arg1\": 1}\n"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue