Ensure DeepSeek R1 thoughts are parsed even without tool calls
This commit is contained in:
		
							parent
							
								
									b6e14a4101
								
							
						
					
					
						commit
						1f5ec59809
					
				
					 2 changed files with 91 additions and 49 deletions
				
			
		|  | @ -345,20 +345,20 @@ def test_weather_tool_call(hf_repo: str, template_override: str | Tuple[str, str | |||
| 
 | ||||
| @pytest.mark.slow | ||||
| @pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [ | ||||
|     (None,                                     128,  "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",      "chatml"), | ||||
|     (None,                                     128,  "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M",        None), | ||||
|     (None,                                     128,  "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M",        "chatml"), | ||||
|     (None,                                     128,  "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M",    ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), | ||||
|     (None,                                     128,  "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",      ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), | ||||
|     (None,                                     128,  "bartowski/functionary-small-v3.2-GGUF:Q8_0",       ("meetkai/functionary-medium-v3.2", None)), | ||||
|     (None,                                     128,  "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), | ||||
|     (None,                                     128,  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), | ||||
|     ("^> 0.56$",                               128,  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), | ||||
|     (None,                                             128,  "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",      "chatml"), | ||||
|     (None,                                             128,  "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M",        None), | ||||
|     (None,                                             128,  "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M",        "chatml"), | ||||
|     (None,                                             128,  "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M",    ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), | ||||
|     (None,                                             128,  "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",      ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), | ||||
|     (None,                                             128,  "bartowski/functionary-small-v3.2-GGUF:Q8_0",       ("meetkai/functionary-medium-v3.2", None)), | ||||
|     (None,                                             128,  "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), | ||||
|     (None,                                             128,  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), | ||||
|     ("^> 0.56$",                                       128,  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), | ||||
| 
 | ||||
|     # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value) | ||||
|     ("[\\s\\S\\r\\n]*?\\b0\\.55644242476$",    128,  "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",      None), | ||||
|     ("[\\s\\S\\r\\n]*?which equals 0\\.5\\.", 8192,  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), | ||||
|     ("**Answer:** 0\\.25\\b",                 8192,  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), | ||||
|     ("[\\s\\S\\r\\n]*?\\b0\\.55644242476$",            128,  "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",      None), | ||||
|     ("[\\s\\S\\r\\n]*?which equals 0\\.5\\.",          8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), | ||||
|     ("[\\s\\S\\r\\n]*?\\*\\*Answer:\\*\\* 0\\.25\\b",  8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), | ||||
| ]) | ||||
| def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None): | ||||
|     global server | ||||
|  | @ -435,6 +435,46 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, | |||
|             f'Expected something like "The y coordinate is 0.56.", got {content}' | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.slow | ||||
| @pytest.mark.parametrize("n_predict,expect_content,expect_thoughts,hf_repo,template_override", [ | ||||
|     (128, "^The sum of 102 and 7 is 109.*",  None,                                          "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None), | ||||
|     (1024, "To find the sum of.*",           "I need to calculate the sum of 102 and 7.*",  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), | ||||
|     (1024, "To find the sum of.*",           "First, I need to add the tens place.*",       "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), | ||||
| ]) | ||||
| def test_thoughts(n_predict: int, expect_content: str | None, expect_thoughts: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): | ||||
|     global server | ||||
|     server.n_slots = 1 | ||||
|     server.jinja = True | ||||
|     server.n_ctx = 8192 * 2 | ||||
|     server.n_predict = n_predict | ||||
|     server.model_hf_repo = hf_repo | ||||
|     server.model_hf_file = None | ||||
|     if isinstance(template_override, tuple): | ||||
|         (template_hf_repo, template_variant) = template_override | ||||
|         server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" | ||||
|         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." | ||||
|     elif isinstance(template_override, str): | ||||
|         server.chat_template = template_override | ||||
|     server.start(timeout_seconds=TIMEOUT_SERVER_START) | ||||
|     res = server.make_request("POST", "/chat/completions", data={ | ||||
|         "max_tokens": n_predict, | ||||
|         "messages": [ | ||||
|             {"role": "user", "content": "What's the sum of 102 and 7?"}, | ||||
|         ] | ||||
|     }, timeout=TIMEOUT_HTTP_REQUEST) | ||||
|     assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" | ||||
|     choice = res.body["choices"][0] | ||||
|     assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}' | ||||
| 
 | ||||
|     content = choice["message"].get("content") | ||||
|     if expect_content is not None: | ||||
|         assert re.match(expect_content, content), f'Expected {expect_content}, got {content}' | ||||
| 
 | ||||
|     thoughts = choice["message"].get("thoughts") | ||||
|     if expect_thoughts is not None: | ||||
|         assert re.match(expect_thoughts, thoughts), f'Expected {expect_thoughts}, got {thoughts}' | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.slow | ||||
| @pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [ | ||||
|     (None,                 "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue