ensure deepseek r1 thoughts parsed even w/o tool calls
This commit is contained in:
parent
b6e14a4101
commit
1f5ec59809
2 changed files with 91 additions and 49 deletions
|
@ -565,6 +565,7 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo
|
||||||
|
|
||||||
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
|
if (!inputs.tools.is_null() && !inputs.tools.empty()) {
|
||||||
data.grammar_lazy = inputs.tool_choice != "required";
|
data.grammar_lazy = inputs.tool_choice != "required";
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::vector<std::string> tool_rules;
|
std::vector<std::string> tool_rules;
|
||||||
|
@ -598,6 +599,7 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
|
||||||
"<|tool▁call▁end|>",
|
"<|tool▁call▁end|>",
|
||||||
};
|
};
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
|
}
|
||||||
auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||||
|
|
||||||
// Hacks to fix the official (broken) prompt.
|
// Hacks to fix the official (broken) prompt.
|
||||||
|
@ -638,7 +640,7 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input)
|
||||||
auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
|
auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
|
||||||
msg.tool_calls = std::move(msg2.tool_calls);
|
msg.tool_calls = std::move(msg2.tool_calls);
|
||||||
} else {
|
} else {
|
||||||
msg.content = rest;
|
msg.content = std::string(rest.begin() + rest.find_first_not_of(" \r\n"), rest.end());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
msg.content = input;
|
msg.content = input;
|
||||||
|
@ -970,6 +972,9 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
|
||||||
// Firefunction v2 requires datetime and functions in the context, even w/o tools.
|
// Firefunction v2 requires datetime and functions in the context, even w/o tools.
|
||||||
return common_chat_params_init_firefunction_v2(tmpl, inputs);
|
return common_chat_params_init_firefunction_v2(tmpl, inputs);
|
||||||
}
|
}
|
||||||
|
if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
|
||||||
|
return common_chat_params_init_deepseek_r1(tmpl, inputs);
|
||||||
|
}
|
||||||
|
|
||||||
if (!has_tools) {
|
if (!has_tools) {
|
||||||
return common_chat_params_init_without_tools(tmpl, inputs);
|
return common_chat_params_init_without_tools(tmpl, inputs);
|
||||||
|
@ -986,9 +991,6 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
|
||||||
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
||||||
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
|
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
|
||||||
}
|
}
|
||||||
if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
|
|
||||||
return common_chat_params_init_deepseek_r1(tmpl, inputs);
|
|
||||||
}
|
|
||||||
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
||||||
return common_chat_params_init_mistral_nemo(tmpl, inputs);
|
return common_chat_params_init_mistral_nemo(tmpl, inputs);
|
||||||
}
|
}
|
||||||
|
|
|
@ -358,7 +358,7 @@ def test_weather_tool_call(hf_repo: str, template_override: str | Tuple[str, str
|
||||||
# TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
|
# TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
|
||||||
("[\\s\\S\\r\\n]*?\\b0\\.55644242476$", 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
("[\\s\\S\\r\\n]*?\\b0\\.55644242476$", 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||||
("[\\s\\S\\r\\n]*?which equals 0\\.5\\.", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
("[\\s\\S\\r\\n]*?which equals 0\\.5\\.", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||||
("**Answer:** 0\\.25\\b", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
("[\\s\\S\\r\\n]*?\\*\\*Answer:\\*\\* 0\\.25\\b", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||||
])
|
])
|
||||||
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||||
global server
|
global server
|
||||||
|
@ -435,6 +435,46 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
|
||||||
f'Expected something like "The y coordinate is 0.56.", got {content}'
|
f'Expected something like "The y coordinate is 0.56.", got {content}'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
@pytest.mark.parametrize("n_predict,expect_content,expect_thoughts,hf_repo,template_override", [
|
||||||
|
(128, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||||
|
(1024, "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||||
|
(1024, "To find the sum of.*", "First, I need to add the tens place.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||||
|
])
|
||||||
|
def test_thoughts(n_predict: int, expect_content: str | None, expect_thoughts: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||||
|
global server
|
||||||
|
server.n_slots = 1
|
||||||
|
server.jinja = True
|
||||||
|
server.n_ctx = 8192 * 2
|
||||||
|
server.n_predict = n_predict
|
||||||
|
server.model_hf_repo = hf_repo
|
||||||
|
server.model_hf_file = None
|
||||||
|
if isinstance(template_override, tuple):
|
||||||
|
(template_hf_repo, template_variant) = template_override
|
||||||
|
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||||
|
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||||
|
elif isinstance(template_override, str):
|
||||||
|
server.chat_template = template_override
|
||||||
|
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||||
|
res = server.make_request("POST", "/chat/completions", data={
|
||||||
|
"max_tokens": n_predict,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the sum of 102 and 7?"},
|
||||||
|
]
|
||||||
|
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||||
|
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||||
|
choice = res.body["choices"][0]
|
||||||
|
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
|
||||||
|
|
||||||
|
content = choice["message"].get("content")
|
||||||
|
if expect_content is not None:
|
||||||
|
assert re.match(expect_content, content), f'Expected {expect_content}, got {content}'
|
||||||
|
|
||||||
|
thoughts = choice["message"].get("thoughts")
|
||||||
|
if expect_thoughts is not None:
|
||||||
|
assert re.match(expect_thoughts, thoughts), f'Expected {expect_thoughts}, got {thoughts}'
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
@pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [
|
@pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [
|
||||||
(None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
(None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue