--think to force any model to return reasoning_content (or just parse <think> for deepseek r1)

This commit is contained in:
ochafik 2025-02-05 12:16:37 +00:00
parent 5d60cebbcc
commit 9d7c3cc51b
9 changed files with 306 additions and 145 deletions

View file

@ -4052,7 +4052,7 @@ int main(int argc, char ** argv) {
}
auto body = json::parse(req.body);
json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.think, ctx_server.chat_templates);
return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
@ -4065,7 +4065,7 @@ int main(int argc, char ** argv) {
// same with handle_chat_completions, but without inference part
const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body);
json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.think, ctx_server.chat_templates);
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
};

View file

@ -439,14 +439,20 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
@pytest.mark.slow
@pytest.mark.parametrize("n_predict,expect_content,expect_reasoning_content,hf_repo,template_override", [
(128, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, "To find the sum of.*", "First, I need to add the tens place.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
@pytest.mark.parametrize("n_predict,think,expect_content,expect_reasoning_content,hf_repo,template_override", [
(1024, True, "^The sum of 102 and 7 is 109.*", "^The user's request is straightforward.*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(128, False, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, True, "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, False, "<think>\nI need[\\s\\S\\r\\n]*</think>\nTo find", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, True, "To find the sum of.*", "First, I need to add the tens place.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
(1024, False, "<think>\nI need[\\s\\S\\r\\n]*</think>To find", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
])
def test_reasoning_content(n_predict: int, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
def test_thoughts(n_predict: int, think: bool, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server
server.n_slots = 1
server.think = think
server.jinja = True
server.n_ctx = 8192 * 2
server.n_predict = n_predict
@ -470,11 +476,15 @@ def test_reasoning_content(n_predict: int, expect_content: str | None, expect_re
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
content = choice["message"].get("content")
if expect_content is not None:
if expect_content is None:
assert content is None, f'Expected no content in {choice["message"]}'
else:
assert re.match(expect_content, content), f'Expected {expect_content}, got {content}'
reasoning_content = choice["message"].get("reasoning_content")
if expect_reasoning_content is not None:
if expect_reasoning_content is None:
assert reasoning_content is None, f'Expected no reasoning content in {choice["message"]}'
else:
assert re.match(expect_reasoning_content, reasoning_content), f'Expected {expect_reasoning_content}, got {reasoning_content}'

View file

@ -78,6 +78,7 @@ class ServerProcess:
draft_max: int | None = None
no_webui: bool | None = None
jinja: bool | None = None
think: bool | None = None
chat_template: str | None = None
chat_template_file: str | None = None
@ -172,6 +173,8 @@ class ServerProcess:
server_args.append("--no-webui")
if self.jinja:
server_args.append("--jinja")
if self.think:
server_args.append("--think")
if self.chat_template:
server_args.extend(["--chat-template", self.chat_template])
if self.chat_template_file:

View file

@ -578,6 +578,7 @@ static json oaicompat_completion_params_parse(const json & body) {
static json oaicompat_completion_params_parse(
const json & body, /* openai api json semantics */
bool use_jinja,
bool think,
const common_chat_templates & chat_templates)
{
json llama_params;
@ -633,9 +634,10 @@ static json oaicompat_completion_params_parse(
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
}
common_chat_inputs inputs;
inputs.messages = body.at("messages");
inputs.tools = tools;
inputs.tool_choice = tool_choice;
inputs.think = think;
inputs.messages = body.at("messages");
inputs.tools = tools;
inputs.tool_choice = tool_choice;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");