Fix and test parsing in the DeepSeek R1 parser
This commit is contained in:
parent
9a6847c857
commit
a682d1216d
2 changed files with 33 additions and 19 deletions
|
@ -606,8 +606,8 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
|
||||||
// Fix up tool call delta example added by Minja
|
// Fix up tool call delta example added by Minja
|
||||||
prompt = std::regex_replace(
|
prompt = std::regex_replace(
|
||||||
prompt,
|
prompt,
|
||||||
std::regex("<|tool▁call▁end|>[\\s\\r\\n]*<|User|>"),
|
std::regex("(<|tool▁call▁end|>)[\\s\\r\\n]*(<|tool▁outputs▁begin|>|<|User|>)"),
|
||||||
"<|tool▁call▁end|><|tool▁calls▁end|><|User|>");
|
"$1<|tool▁calls▁end|><|end▁of▁sentence|>$2");
|
||||||
}
|
}
|
||||||
data.prompt = prompt;
|
data.prompt = prompt;
|
||||||
data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
|
data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
|
||||||
|
@ -617,7 +617,7 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input)
|
||||||
static std::regex trigger_regex("(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)?");
|
static std::regex trigger_regex("(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)?");
|
||||||
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
||||||
static std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
|
static std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
|
||||||
static std::regex think_regex(R"(<think>([\s\S\n]*)(</think>)?([\s\S\r\n]*))");
|
static std::regex think_regex("<think>([\\s\\S\\n]*?)</think>([\\s\\S\\r\\n]*)");
|
||||||
auto msg = parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);
|
auto msg = parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);
|
||||||
std::smatch match;
|
std::smatch match;
|
||||||
if (std::regex_match(msg.content, match, think_regex)) {
|
if (std::regex_match(msg.content, match, think_regex)) {
|
||||||
|
|
|
@ -108,6 +108,8 @@ static std::string dump(const json & j) {
|
||||||
static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
|
static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
|
||||||
assert_equals(expected.role, actual.role);
|
assert_equals(expected.role, actual.role);
|
||||||
assert_equals(expected.content, actual.content);
|
assert_equals(expected.content, actual.content);
|
||||||
|
assert_equals(expected.thoughts, actual.thoughts);
|
||||||
|
assert_equals(expected.tool_plan, actual.tool_plan);
|
||||||
assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
|
assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
|
||||||
for (size_t i = 0; i < expected.tool_calls.size(); i++) {
|
for (size_t i = 0; i < expected.tool_calls.size(); i++) {
|
||||||
const auto & expected_tool_call = expected.tool_calls[i];
|
const auto & expected_tool_call = expected.tool_calls[i];
|
||||||
|
@ -226,7 +228,8 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
|
||||||
*/
|
*/
|
||||||
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
|
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
|
||||||
const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
|
const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
|
||||||
bool expect_grammar_triggered = true) {
|
bool expect_grammar_triggered = true,
|
||||||
|
bool test_grammar_if_triggered = true) {
|
||||||
common_chat_msg expected_msg = msg_from_json(test_message);
|
common_chat_msg expected_msg = msg_from_json(test_message);
|
||||||
|
|
||||||
auto user_message = json{
|
auto user_message = json{
|
||||||
|
@ -277,7 +280,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
||||||
assert_equals(expect_grammar_triggered, grammar_triggered);
|
assert_equals(expect_grammar_triggered, grammar_triggered);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (grammar_triggered && !match_string(constrained, grammar.get())) {
|
if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
|
||||||
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
|
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
|
||||||
"\n\nGrammar: " + data.params.grammar);
|
"\n\nGrammar: " + data.params.grammar);
|
||||||
}
|
}
|
||||||
|
@ -290,6 +293,11 @@ static void test_template_output_parsers() {
|
||||||
{ "role", "assistant" },
|
{ "role", "assistant" },
|
||||||
{ "content", "Hello, world!" },
|
{ "content", "Hello, world!" },
|
||||||
};
|
};
|
||||||
|
json text_thoughts_message {
|
||||||
|
{ "role", "assistant" },
|
||||||
|
{ "content", "Hello, world!" },
|
||||||
|
{ "thoughts", "I'm thinking" },
|
||||||
|
};
|
||||||
json tool_calls = json::array({{
|
json tool_calls = json::array({{
|
||||||
{ "type", "function" },
|
{ "type", "function" },
|
||||||
{ "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
|
{ "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
|
||||||
|
@ -389,6 +397,26 @@ static void test_template_output_parsers() {
|
||||||
inputs_tools_builtin.tools = json::array();
|
inputs_tools_builtin.tools = json::array();
|
||||||
inputs_tools_builtin.tools.push_back(python_tool);
|
inputs_tools_builtin.tools.push_back(python_tool);
|
||||||
|
|
||||||
|
{
|
||||||
|
// Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt.
|
||||||
|
const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
|
||||||
|
"<s>", "</s>");
|
||||||
|
std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
|
||||||
|
|
||||||
|
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
||||||
|
|
||||||
|
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
|
||||||
|
test_template(tmpl, end_tokens, text_thoughts_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
|
||||||
|
assert_msg_equals(msg_from_json(text_thoughts_message), common_chat_parse("<think>I'm thinking</think>Hello, world!", COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
||||||
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
|
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
||||||
|
"```json\n"
|
||||||
|
"{\"arg1\": 1}\n"
|
||||||
|
// Look what's not here: <|tool▁calls▁end|> (also missing the <|end▁of▁sentence|>, but that is removed lazily by the test's delta logic)
|
||||||
|
"```<|tool▁call▁end|>",
|
||||||
|
/* expect_grammar_triggered= */ true,
|
||||||
|
/* test_grammar_if_triggered= */ false);
|
||||||
|
}
|
||||||
{
|
{
|
||||||
// Not supported yet
|
// Not supported yet
|
||||||
const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "<s>", "</s>");
|
||||||
|
@ -558,20 +586,6 @@ static void test_template_output_parsers() {
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
||||||
}
|
}
|
||||||
{
|
|
||||||
const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
|
|
||||||
"<s>", "</s>");
|
|
||||||
std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
|
|
||||||
|
|
||||||
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
|
||||||
|
|
||||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
|
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
||||||
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
|
||||||
"```json\n"
|
|
||||||
"{\"arg1\": 1}\n"
|
|
||||||
"```<|tool▁call▁end|>");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue