Update test_chat_handlers.py

This commit is contained in:
ochafik 2024-03-30 01:10:14 +00:00
parent d8a53eadf2
commit ad2f4c119a

View file

@ -1,8 +1,9 @@
#
#
# python -m examples.openai.test_chat_handlers | tee examples/openai/test_chat_handlers.md
import json
import sys
from examples.openai.api import FunctionCall, Message, Tool, ToolCall, ToolFunction
from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, ToolsPromptStyle, get_chat_handler
@ -143,15 +144,26 @@ TEST_TEMPLATES = {
"bos_token": "<s>"
},
}
# Templates (by key in TEST_TEMPLATES) expected to report
# potentially_supports_parallel_calls == True in the checks below.
# Idiom fix: use a set literal instead of set([...]) (ruff C405).
MODELS_WITH_PARALLEL_CALLS = {"functionary_v2_2"}
# Materialize each raw template spec dict into a ChatTemplate model,
# keyed by the same template name.
TEST_TEMPLATES = {name: ChatTemplate(**spec) for name, spec in TEST_TEMPLATES.items()}
if __name__ == "__main__":
failures = []
print(f'\nMessages:\n\n```js\n{json.dumps([m.model_dump() for m in TEST_MESSAGES], indent=2)}\n```\n')
def check(b: bool, msg: str):
    # Record a failed assertion: log it to stderr and remember it in the
    # enclosing `failures` list so the run can report all problems at once.
    if b:
        return
    sys.stderr.write(f'FAILURE: {msg}\n\n')
    failures.append(msg)
functionary_v2_2 = TEST_TEMPLATES["functionary_v2_2"]
check(functionary_v2_2.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2, "functionary_v2_2 should be inferred as TYPESCRIPT_FUNCTIONARY_V2")
for model_name, chat_template in TEST_TEMPLATES.items():
check(chat_template.potentially_supports_parallel_calls == (model_name in MODELS_WITH_PARALLEL_CALLS),
f"{model_name} should {'not ' if model_name not in MODELS_WITH_PARALLEL_CALLS else ''} be detected as potentially supporting parallel calls")
# if model_name == 'hermes_2_pro_mistral':
# print("Skipping hermes_2_pro_mistral")
@ -159,8 +171,7 @@ if __name__ == "__main__":
# Render `msgs` through the current chat template and verify every expected
# substring appears in the resulting prompt.
# NOTE(review): this flattened diff view shows BOTH the old reporting path
# (direct failures.append on a miss) and its replacement (the check(...)
# call added by this commit); in the committed file only the check(...)
# line should remain — confirm against the repository.
def check_finds(msgs, strings_to_find):
prompt = chat_template.render(msgs, add_generation_prompt=True)
for s in strings_to_find:
if str(s) not in prompt:
failures.append(f"Missing {s} in prompt for {model_name}:\n{prompt}")
check(str(s) in prompt, f"Missing {s} in prompt for {model_name}:\n{prompt}")
# Smoke checks: a user question and an assistant answer must survive rendering.
check_finds([PROMPT_MESSAGE], (QUESTION,))
check_finds([ASSIST_MESSAGE], (ANSWER,))
@ -187,7 +198,7 @@ if __name__ == "__main__":
tools=[],
),
}
for style in ToolsPromptStyle:
if (style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2) != (model_name.startswith("functionary")):
continue
@ -202,9 +213,9 @@ if __name__ == "__main__":
for tn, args in argss.items():
ch = get_chat_handler(args, parallel_calls=True, tool_style=style)
print(f'\n### {tn}\n')
print(f'\nPrompt:\n\n```json\n{ch.output_format_prompt.content}\n```\n')
print(f'\nGrammar:\n\n```js\n{ch.grammar}\n```\n')