From ad2f4c119a59cc5187643f3996b3b2d678ca4d36 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Sat, 30 Mar 2024 01:10:14 +0000
Subject: [PATCH] Update test_chat_handlers.py

---
 examples/openai/test_chat_handlers.py | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/examples/openai/test_chat_handlers.py b/examples/openai/test_chat_handlers.py
index 7b44d29ca..33173ef06 100644
--- a/examples/openai/test_chat_handlers.py
+++ b/examples/openai/test_chat_handlers.py
@@ -1,8 +1,9 @@
 #
 #
 # python -m examples.openai.test_chat_handlers | tee examples/openai/test_chat_handlers.md
-
+
 import json
+import sys
 
 from examples.openai.api import FunctionCall, Message, Tool, ToolCall, ToolFunction
 from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, ToolsPromptStyle, get_chat_handler
@@ -143,15 +144,26 @@ TEST_TEMPLATES = {
         "bos_token": ""
     },
 }
+MODELS_WITH_PARALLEL_CALLS = set(["functionary_v2_2"])
 TEST_TEMPLATES = {k: ChatTemplate(**v) for k, v in TEST_TEMPLATES.items()}
 
 if __name__ == "__main__":
-
+    failures = []
 
     print(f'\nMessages:\n\n```js\n{json.dumps([m.model_dump() for m in TEST_MESSAGES], indent=2)}\n```\n')
 
+    def check(b: bool, msg: str):
+        if not b:
+            sys.stderr.write(f'FAILURE: {msg}\n\n')
+            failures.append(msg)
+
+    functionary_v2_2 = TEST_TEMPLATES["functionary_v2_2"]
+    check(functionary_v2_2.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2, "functionary_v2_2 should be inferred as TYPESCRIPT_FUNCTIONARY_V2")
+
     for model_name, chat_template in TEST_TEMPLATES.items():
+        check(chat_template.potentially_supports_parallel_calls == (model_name in MODELS_WITH_PARALLEL_CALLS),
+              f"{model_name} should {'not ' if model_name not in MODELS_WITH_PARALLEL_CALLS else ''}be detected as potentially supporting parallel calls")
 
         # if model_name == 'hermes_2_pro_mistral':
         #     print("Skipping hermes_2_pro_mistral")
         #     continue
@@ -159,8 +171,7 @@ if __name__ == "__main__":
         def check_finds(msgs, strings_to_find):
             prompt = chat_template.render(msgs, add_generation_prompt=True)
             for s in strings_to_find:
-                if str(s) not in prompt:
-                    failures.append(f"Missing {s} in prompt for {model_name}:\n{prompt}")
+                check(str(s) in prompt, f"Missing {s} in prompt for {model_name}:\n{prompt}")
 
         check_finds([PROMPT_MESSAGE], (QUESTION,))
         check_finds([ASSIST_MESSAGE], (ANSWER,))
@@ -187,7 +198,7 @@ if __name__ == "__main__":
                 tools=[],
             ),
         }
-
+
         for style in ToolsPromptStyle:
             if (style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2) != (model_name.startswith("functionary")):
                 continue
@@ -202,9 +213,9 @@ if __name__ == "__main__":
 
             for tn, args in argss.items():
                 ch = get_chat_handler(args, parallel_calls=True, tool_style=style)
-
+
                 print(f'\n### {tn}\n')
-
+
                 print(f'\nPrompt:\n\n```json\n{ch.output_format_prompt.content}\n```\n')
 
                 print(f'\nGrammar:\n\n```js\n{ch.grammar}\n```\n')
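
Note on the check() helper introduced above: the visible hunks accumulate messages in
`failures` but do not show how the list is consumed at the end of the run. A minimal,
self-contained sketch of the pattern (the exit handling at the bottom is an assumption
for illustration, not part of this patch):

    import sys

    failures = []

    def check(b: bool, msg: str):
        # Record the failure and keep going, so a single run reports every
        # problem instead of stopping at the first failed expectation.
        if not b:
            sys.stderr.write(f'FAILURE: {msg}\n\n')
            failures.append(msg)

    check(1 + 1 == 2, "arithmetic sanity")              # passes, records nothing
    check("<tool_call>" in "plain text", "tag present") # fails, logged and recorded

    # Assumed follow-up at the end of the script (not shown in the hunks above):
    if failures:
        sys.exit(f'{len(failures)} check(s) failed')

This keeps the file runnable as a plain `python -m` module (per the comment at the top
of the file) rather than a pytest suite: every expectation runs on each invocation, and
the exit code only turns non-zero once all templates have been exercised.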
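
Note on MODELS_WITH_PARALLEL_CALLS: the set doubles as an expectation table. Each
template's detected `potentially_supports_parallel_calls` flag is compared against set
membership, so marking a new template as parallel-capable is a one-line change. A small
illustration of the same pattern with hypothetical data (the model names and detected
values below are made up for the example, not taken from the patch):

    # Hypothetical expectation table in the style of MODELS_WITH_PARALLEL_CALLS.
    EXPECTED_PARALLEL = {"model_a"}

    # Stand-in for flags a real run would read from each ChatTemplate.
    detected = {"model_a": True, "model_b": False}

    for name, got in detected.items():
        want = name in EXPECTED_PARALLEL
        assert got == want, f"{name}: expected parallel support {want}, detected {got}"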