Update test_chat_handlers.py

This commit is contained in:
ochafik 2024-03-30 01:10:14 +00:00
parent d8a53eadf2
commit ad2f4c119a

View file

@ -3,6 +3,7 @@
# python -m examples.openai.test_chat_handlers | tee examples/openai/test_chat_handlers.md # python -m examples.openai.test_chat_handlers | tee examples/openai/test_chat_handlers.md
import json import json
import sys
from examples.openai.api import FunctionCall, Message, Tool, ToolCall, ToolFunction from examples.openai.api import FunctionCall, Message, Tool, ToolCall, ToolFunction
from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, ToolsPromptStyle, get_chat_handler from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, ToolsPromptStyle, get_chat_handler
@ -143,6 +144,7 @@ TEST_TEMPLATES = {
"bos_token": "<s>" "bos_token": "<s>"
}, },
} }
MODELS_WITH_PARALLEL_CALLS = set(["functionary_v2_2"])
TEST_TEMPLATES = {k: ChatTemplate(**v) for k, v in TEST_TEMPLATES.items()} TEST_TEMPLATES = {k: ChatTemplate(**v) for k, v in TEST_TEMPLATES.items()}
if __name__ == "__main__": if __name__ == "__main__":
@ -151,7 +153,17 @@ if __name__ == "__main__":
print(f'\nMessages:\n\n```js\n{json.dumps([m.model_dump() for m in TEST_MESSAGES], indent=2)}\n```\n') print(f'\nMessages:\n\n```js\n{json.dumps([m.model_dump() for m in TEST_MESSAGES], indent=2)}\n```\n')
def check(b: bool, msg: str):
    # Record a test failure: when the condition does not hold, report it on
    # stderr and remember the message in the module-level `failures` list.
    if b:
        return
    sys.stderr.write(f'FAILURE: {msg}\n\n')
    failures.append(msg)
functionary_v2_2 = TEST_TEMPLATES["functionary_v2_2"]
check(functionary_v2_2.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2, "functionary_v2_2 should be inferred as TYPESCRIPT_FUNCTIONARY_V2")
for model_name, chat_template in TEST_TEMPLATES.items(): for model_name, chat_template in TEST_TEMPLATES.items():
check(chat_template.potentially_supports_parallel_calls == (model_name in MODELS_WITH_PARALLEL_CALLS),
f"{model_name} should {'not ' if model_name not in MODELS_WITH_PARALLEL_CALLS else ''} be detected as potentially supporting parallel calls")
# if model_name == 'hermes_2_pro_mistral': # if model_name == 'hermes_2_pro_mistral':
# print("Skipping hermes_2_pro_mistral") # print("Skipping hermes_2_pro_mistral")
@ -159,8 +171,7 @@ if __name__ == "__main__":
def check_finds(msgs, strings_to_find): def check_finds(msgs, strings_to_find):
prompt = chat_template.render(msgs, add_generation_prompt=True) prompt = chat_template.render(msgs, add_generation_prompt=True)
for s in strings_to_find: for s in strings_to_find:
if str(s) not in prompt: check(str(s) in prompt, f"Missing {s} in prompt for {model_name}:\n{prompt}")
failures.append(f"Missing {s} in prompt for {model_name}:\n{prompt}")
check_finds([PROMPT_MESSAGE], (QUESTION,)) check_finds([PROMPT_MESSAGE], (QUESTION,))
check_finds([ASSIST_MESSAGE], (ANSWER,)) check_finds([ASSIST_MESSAGE], (ANSWER,))