align Command R7B w/ --think / reasoning_content behaviour

Olivier Chafik 2025-02-05 15:47:37 +00:00
parent 3841a163ef
commit e6d9b52480
9 changed files with 176 additions and 87 deletions

View file

@@ -1978,7 +1978,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     add_opt(common_arg(
         {"--think"},
         "*experimental* thinking mode (default: disabled)\n"
-        "returns reasoning_content in messages, forcing model to think unless it supports native <think> tags (DeepSeek R1)\n"
+        "returns reasoning_content in messages, forcing model to think unless it supports native <think> tags (DeepSeek R1, Command R7B)\n"
        "only supported for non-streamed responses",
        [](common_params & params) {
            params.think = true;

View file

@@ -316,7 +316,7 @@ class chat_template {
         auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role;
         auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools;
-        auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples;
+        auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples && caps_.supports_tool_calls;
         auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls;
         auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses;
         auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments;

View file

@@ -12,12 +12,13 @@ std::string common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
         case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
         case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
-        case COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK: return "DeepSeek R1 (extract <think>)";
+        case COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK: return "DeepSeek R1 (extract reasoning_content)";
         case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
         case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
         case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
+        case COMMON_CHAT_FORMAT_COMMAND_R7B_THINK: return "Command R7B (extract reasoning_content)";
         default:
             throw std::runtime_error("Unknown chat format");
     }
@@ -469,22 +470,49 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
         "<|END_THINKING|>",
         "<|END_ACTION|>",
     };
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
-    data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
+    auto adjusted_messages = json::array();
+    for (const auto & msg : inputs.messages) {
+        auto has_reasoning_content = msg.contains("reasoning_content") && msg["reasoning_content"].is_string();
+        auto has_tool_calls = msg.contains("tool_calls") && msg["tool_calls"].is_array();
+        if (has_reasoning_content && has_tool_calls) {
+            auto adjusted_message = msg;
+            adjusted_message["tool_plan"] = msg["reasoning_content"];
+            adjusted_message.erase("reasoning_content");
+            adjusted_messages.push_back(adjusted_message);
+        } else {
+            adjusted_messages.push_back(msg);
+        }
+    }
+    // } else {
+    //     adjusted_messages = inputs.messages;
+    // }
+    data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
+    data.format = inputs.think ? COMMON_CHAT_FORMAT_COMMAND_R7B_THINK : COMMON_CHAT_FORMAT_COMMAND_R7B;
     return data;
 }
-static common_chat_msg common_chat_parse_command_r7b(const std::string & input) {
-    static std::regex response_regex("<\\|START_RESPONSE\\|>([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
-    static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
+static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool think) {
+    static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|>)([\\s\\S\\n\\r]*)");
+    static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
+    static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
     std::smatch match;
     common_chat_msg result;
     result.role = "assistant";
-    if (std::regex_match(input, match, response_regex)) {
-        result.content = match[1].str();
-    } else if (std::regex_match(input, match, thought_action_regex)) {
-        result.tool_plan = match[1].str();
-        auto actions_str = match[2].str();
+    std::string rest = input;
+    if (std::regex_match(rest, match, thought_regex)) {
+        if (think) {
+            result.reasoning_content = match[2].str();
+        } else if (!match[2].str().empty()) {
+            // Let the unparsed thinking tags through in content only if their insides aren't empty.
+            result.content = match[1].str();
+        }
+        rest = match[3].str();
+    }
+    if (std::regex_match(rest, match, action_regex)) {
+        auto actions_str = match[1].str();
         auto actions = json::parse(actions_str);
         for (const auto & action : actions) {
             result.tool_calls.push_back({
@@ -493,9 +521,12 @@ static common_chat_msg common_chat_parse_command_r7b(const std::string & input)
                 /* .id = */ action["tool_call_id"],
             });
         }
+    } else if (std::regex_match(rest, match, response_regex)) {
+        auto response = match[1].str();
+        result.content += response;
     } else {
         LOG_ERR("Failed to parse command_r output");
-        result.content = input;
+        result.content += rest;
     }
     return result;
 }
@@ -1038,6 +1069,11 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
         return common_chat_params_init_deepseek_r1(tmpl, inputs);
     }
+    // Command R7B: use handler in all cases except JSON schema (thinking / tools).
+    if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
+        return common_chat_params_init_command_r7b(tmpl, inputs);
+    }
     // Use generic handler when forcing thoughts or JSON schema for final output
     // TODO: support thinking mode and/or JSON schema in handlers below this.
     if (inputs.think || (!inputs.tools.is_null() && inputs.json_schema.is_object())) {
@@ -1081,11 +1117,6 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
         return common_chat_params_init_mistral_nemo(tmpl, inputs);
     }
-    // Command R7B (w/ tools)
-    if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) {
-        return common_chat_params_init_command_r7b(tmpl, inputs);
-    }
     // Generic fallback
     return common_chat_params_init_generic(tmpl, inputs);
 }
@@ -1123,7 +1154,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
         case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
             return common_chat_parse_firefunction_v2(input);
         case COMMON_CHAT_FORMAT_COMMAND_R7B:
-            return common_chat_parse_command_r7b(input);
+            return common_chat_parse_command_r7b(input, /* think= */ false);
+        case COMMON_CHAT_FORMAT_COMMAND_R7B_THINK:
+            return common_chat_parse_command_r7b(input, /* think= */ true);
         default:
             throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
     }
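For illustration only (not part of the commit): a minimal, self-contained C++ program sketching the extraction that the new Command R7B parser performs. The two regexes are copied from the hunk above; the sample output string and the surrounding program are assumptions for demonstration.

// Standalone sketch: split a hypothetical Command R7B completion into its
// thinking section (exposed as reasoning_content when --think is set) and its
// action section (parsed into tool_calls). Regexes match the ones added above.
#include <iostream>
#include <regex>
#include <string>

int main() {
    const std::string output =
        "<|START_THINKING|>I should call the special function.<|END_THINKING|>"
        "<|START_ACTION|>[{\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}]<|END_ACTION|>";

    static const std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|>)([\\s\\S\\n\\r]*)");
    static const std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");

    std::smatch match;
    std::string rest = output;
    if (std::regex_match(rest, match, thought_regex)) {
        std::cout << "reasoning_content: " << match[2].str() << "\n";
        rest = match[3].str();  // whatever follows the thinking block
    }
    if (std::regex_match(rest, match, action_regex)) {
        std::cout << "tool call JSON:    " << match[1].str() << "\n";
    }
    return 0;
}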

View file

@@ -35,6 +35,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
     COMMON_CHAT_FORMAT_HERMES_2_PRO,
     COMMON_CHAT_FORMAT_COMMAND_R7B,
+    COMMON_CHAT_FORMAT_COMMAND_R7B_THINK,
     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };

View file

@@ -625,7 +625,6 @@ struct common_chat_msg {
     std::string content;
     std::vector<common_tool_call> tool_calls;
     std::string reasoning_content = "";
-    std::string tool_plan = "";
 };
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid

View file

@@ -127,6 +127,8 @@ The project is under active development, and we are [looking for feedback and co
 | `--grammar-file FNAME` | file to read grammar from |
 | `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object<br/>For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead |
 | `--jinja` | Enable experimental Jinja templating engine (required for tool use) |
+| `--think` | Enable experimental thinking mode (extracts DeepSeek R1 & Command R7B's native thinking tags and forces any other model to think before responding, resulting thoughts are in the `reasoning_content` output field) (requires `--jinja`) |
 
 **Example-specific params**
@@ -1223,10 +1225,10 @@ curl http://localhost:8080/v1/chat/completions \
 # Native support for DeepSeek R1 works best w/ our own template (official template buggy)
-llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L \
+llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L --think \
   --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
-llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q4_K_M \
+llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q4_K_M --think \
   --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
 # Native support requires the right template for these GGUFs:
@@ -1240,7 +1242,7 @@ curl http://localhost:8080/v1/chat/completions \
 llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-IQ1_M.gguf \
   --chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use )
-llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \
+llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L --think \
   --chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use )
 # Generic format support

View file

@@ -748,9 +748,6 @@ struct server_task_result_cmpl_final : server_task_result {
             if (!msg.reasoning_content.empty()) {
                 message["reasoning_content"] = msg.reasoning_content;
             }
-            if (!msg.tool_plan.empty()) {
-                message["tool_plan"] = msg.tool_plan;
-            }
             json choice {
                 {"finish_reason", finish_reason},

View file

@@ -274,43 +274,44 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
 @pytest.mark.slow
-@pytest.mark.parametrize("hf_repo,template_override", [
-    ("bartowski/c4ai-command-r7b-12-2024-GGUF:Q4_K_M", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")),
-    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
-    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
-    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
-    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
-    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
-    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
-    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
-    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
-    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
-    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
-    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
-    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
-    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
-    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
-    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
-    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
-    ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+@pytest.mark.parametrize("think,hf_repo,template_override", [
+    (True, "bartowski/c4ai-command-r7b-12-2024-GGUF:Q4_K_M", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")),
+    (False, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
+    (False, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
+    (False, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
+    (False, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
+    (False, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
+    (False, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
+    (False, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
+    (False, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
+    (False, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
+    (False, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
+    (False, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
+    (False, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
+    (False, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
+    (False, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
+    (False, "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
+    (False, "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
+    (True, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
     # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
-    ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
+    (False, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
     # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
 ])
-def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None):
+def test_weather(think: bool, hf_repo: str, template_override: Tuple[str, str | None] | None):
     global server
     n_predict = 512
+    server.think = think
     server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192
@@ -488,44 +489,45 @@ def test_thoughts(n_predict: int, think: bool, expect_content: str | None, expec
 @pytest.mark.slow
-@pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [
-    (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", "chatml"),
-    (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
-    (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
-    (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)),
-    (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
-    (None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
-    ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
-    ('{"code":"print("}', "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
-    ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
-    (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
-    (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
-    (None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
-    (None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
-    (None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
-    (None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
-    (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
-    (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
+@pytest.mark.parametrize("think,expected_arguments_override,hf_repo,template_override", [
+    (True, None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    (True, None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", "chatml"),
+    (False, None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
+    (False, None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
+    (False, None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)),
+    (False, None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
+    (False, None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
+    (False, '{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
+    (False, '{"code":"print("}', "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    (False, None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
+    (False, '{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    (False, '{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
+    (False, None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
+    (False, None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
+    (False, None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
+    (False, None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
+    (False, None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
+    (False, None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
+    (False, None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
+    (False, None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
     # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
-    (None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
+    (False, None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
 ])
-def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_hello_world(think: bool, expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
     global server
     server.n_slots = 1
     server.jinja = True
+    server.think = think
     server.n_ctx = 8192
     server.n_predict = 512 # High because of DeepSeek R1
     server.model_hf_repo = hf_repo

View file

@@ -24,7 +24,7 @@ static common_chat_msg msg_from_json(const json & message) {
         ret.content = message.at("content");
     }
     if (message.contains("tool_plan")) {
-        ret.tool_plan = message.at("tool_plan");
+        ret.reasoning_content = message.at("tool_plan");
     }
     if (message.contains("reasoning_content")) {
         ret.reasoning_content = message.at("reasoning_content");
@@ -109,7 +109,6 @@ static void assert_msg_equals(const common_chat_msg & expected, const common_cha
     assert_equals(expected.role, actual.role);
     assert_equals(expected.content, actual.content);
     assert_equals(expected.reasoning_content, actual.reasoning_content);
-    assert_equals(expected.tool_plan, actual.tool_plan);
     assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
     for (size_t i = 0; i < expected.tool_calls.size(); i++) {
         const auto & expected_tool_call = expected.tool_calls[i];
@@ -181,13 +180,15 @@ struct delta_data {
 static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
                              const json & user_message, const json & delta_message, const json & tools,
-                             const json & tool_choice) {
+                             const json & tool_choice,
+                             bool think = false) {
     common_chat_inputs inputs;
     inputs.parallel_tool_calls = true;
     inputs.messages = json::array();
     inputs.messages.push_back(user_message);
     inputs.tools = tools;
     inputs.tool_choice = tool_choice;
+    inputs.think = think;
     auto params_prefix = common_chat_params_init(tmpl, inputs);
     inputs.messages.push_back(delta_message);
@@ -229,7 +230,8 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
 static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
                           const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
                           bool expect_grammar_triggered = true,
-                          bool test_grammar_if_triggered = true) {
+                          bool test_grammar_if_triggered = true,
+                          bool think = false) {
     common_chat_msg expected_msg = msg_from_json(test_message);
     auto user_message = json{
@@ -238,7 +240,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
     };
     for (const auto & tool_choice : json({ "auto", "required" })) {
-        auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice);
+        auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice, think);
         if (!expected_delta.empty()) {
             assert_equals(expected_delta, data.delta);
         }
@@ -297,10 +299,14 @@ static void test_template_output_parsers() {
         { "role", "assistant" },
         { "content", "Hello, world!\nWhat's up?" },
     };
-    json message_assist_thoughts_unparsed {
+    json message_assist_thoughts_unparsed_think {
         { "role", "assistant" },
         { "content", "<think>I'm thinking</think>Hello, world!\nWhat's up?" },
     };
+    json message_assist_thoughts_unparsed_r7b {
+        { "role", "assistant" },
+        { "content", "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?" },
+    };
     json message_assist_thoughts {
         { "role", "assistant" },
         { "content", "Hello, world!\nWhat's up?" },
@@ -371,7 +377,6 @@ static void test_template_output_parsers() {
     json message_assist_call_idx {
         { "role", "assistant"},
         { "content", {}},
-        { "tool_plan", "I'm not so sure"},
         { "tool_calls", {
             {
                 { "type", "function" },
@@ -387,6 +392,8 @@ static void test_template_output_parsers() {
         { "content", {} },
         { "tool_calls", tool_calls }
     };
+    json message_assist_call_tool_plan_idx = message_assist_call_idx;
+    message_assist_call_tool_plan_idx["tool_plan"] = "I'm thinking";
     auto python_message_assist_call = json{
         { "role", "assistant" },
@@ -448,14 +455,52 @@ static void test_template_output_parsers() {
         const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };
-        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_THINK, common_chat_params_init(tmpl, inputs_tools_think).format);
+        assert_msg_equals(msg_from_json(message_assist),
+            common_chat_parse(
+                "Hello, world!\nWhat's up?",
+                COMMON_CHAT_FORMAT_COMMAND_R7B));
+        assert_msg_equals(msg_from_json(message_assist),
+            common_chat_parse(
+                "Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                COMMON_CHAT_FORMAT_COMMAND_R7B));
+        assert_msg_equals(msg_from_json(message_assist),
+            common_chat_parse(
+                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                COMMON_CHAT_FORMAT_COMMAND_R7B));
+        assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
+            common_chat_parse(
+                "<|START_THINKING|>I'm thinking<|END_THINKING|>"
+                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                COMMON_CHAT_FORMAT_COMMAND_R7B));
+        assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
+            common_chat_parse(
+                "<|START_THINKING|>I'm thinking<|END_THINKING|>"
+                "Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                COMMON_CHAT_FORMAT_COMMAND_R7B));
+        assert_msg_equals(msg_from_json(message_assist_thoughts),
+            common_chat_parse(
+                "<|START_THINKING|>I'm thinking<|END_THINKING|>"
+                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                COMMON_CHAT_FORMAT_COMMAND_R7B_THINK));
         test_template(tmpl, end_tokens, message_assist_call_idx, tools,
-            "<|START_THINKING|>I'm not so sure<|END_THINKING|>"
+            "<|START_THINKING|><|END_THINKING|>"
             "<|START_ACTION|>[\n"
            " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
            "]<|END_ACTION|>");
+        test_template(tmpl, end_tokens, message_assist_call_tool_plan_idx, tools,
+            "<|START_THINKING|>I'm thinking<|END_THINKING|>"
+            "<|START_ACTION|>[\n"
+            " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
+            "]<|END_ACTION|>",
+            /* expect_grammar_triggered= */ true,
+            /* test_grammar_if_triggered= */ true,
+            /* think= */ true);
         test_template(tmpl, end_tokens, message_assist, tools,
             "<|START_RESPONSE|>Hello, world!\n"
             "What's up?<|END_RESPONSE|>",
@@ -616,12 +661,17 @@ static void test_template_output_parsers() {
             "<s>", "</s>");
         std::vector<std::string> end_tokens{ "<end▁of▁sentence>" };
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK, common_chat_params_init(tmpl, inputs_tools_think).format);
         test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1));
-        assert_msg_equals(msg_from_json(message_assist_thoughts), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK));
+        assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
+            common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+                COMMON_CHAT_FORMAT_DEEPSEEK_R1));
+        assert_msg_equals(msg_from_json(message_assist_thoughts),
+            common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+                COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK));
         // test_template(tmpl, end_tokens, message_assist_call, tools,
         //     "<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function\n"
         //     "```json\n"
@@ -637,12 +687,17 @@ static void test_template_output_parsers() {
             "<s>", "</s>");
         std::vector<std::string> end_tokens{ "<end▁of▁sentence>" };
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK, common_chat_params_init(tmpl, inputs_tools_think).format);
         test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1));
-        assert_msg_equals(msg_from_json(message_assist_thoughts), common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK));
+        assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
+            common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+                COMMON_CHAT_FORMAT_DEEPSEEK_R1));
+        assert_msg_equals(msg_from_json(message_assist_thoughts),
+            common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+                COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK));
         assert_msg_equals(msg_from_json(message_assist_call_thoughts_unparsed),
             common_chat_parse(