Update test_chat_handlers.py
This commit is contained in:
parent
d8a53eadf2
commit
ad2f4c119a
1 changed files with 18 additions and 7 deletions
|
@ -1,8 +1,9 @@
|
|||
#
|
||||
#
|
||||
# python -m examples.openai.test_chat_handlers | tee examples/openai/test_chat_handlers.md
|
||||
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
from examples.openai.api import FunctionCall, Message, Tool, ToolCall, ToolFunction
|
||||
from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, ToolsPromptStyle, get_chat_handler
|
||||
|
@ -143,15 +144,26 @@ TEST_TEMPLATES = {
|
|||
"bos_token": "<s>"
|
||||
},
|
||||
}
|
||||
MODELS_WITH_PARALLEL_CALLS = set(["functionary_v2_2"])
|
||||
TEST_TEMPLATES = {k: ChatTemplate(**v) for k, v in TEST_TEMPLATES.items()}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
||||
failures = []
|
||||
|
||||
print(f'\nMessages:\n\n```js\n{json.dumps([m.model_dump() for m in TEST_MESSAGES], indent=2)}\n```\n')
|
||||
|
||||
def check(b: bool, msg: str):
|
||||
if not b:
|
||||
sys.stderr.write(f'FAILURE: {msg}\n\n')
|
||||
failures.append(msg)
|
||||
|
||||
functionary_v2_2 = TEST_TEMPLATES["functionary_v2_2"]
|
||||
check(functionary_v2_2.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2, "functionary_v2_2 should be inferred as TYPESCRIPT_FUNCTIONARY_V2")
|
||||
|
||||
for model_name, chat_template in TEST_TEMPLATES.items():
|
||||
check(chat_template.potentially_supports_parallel_calls == (model_name in MODELS_WITH_PARALLEL_CALLS),
|
||||
f"{model_name} should {'not ' if model_name not in MODELS_WITH_PARALLEL_CALLS else ''} be detected as potentially supporting parallel calls")
|
||||
|
||||
# if model_name == 'hermes_2_pro_mistral':
|
||||
# print("Skipping hermes_2_pro_mistral")
|
||||
|
@ -159,8 +171,7 @@ if __name__ == "__main__":
|
|||
def check_finds(msgs, strings_to_find):
|
||||
prompt = chat_template.render(msgs, add_generation_prompt=True)
|
||||
for s in strings_to_find:
|
||||
if str(s) not in prompt:
|
||||
failures.append(f"Missing {s} in prompt for {model_name}:\n{prompt}")
|
||||
check(str(s) in prompt, f"Missing {s} in prompt for {model_name}:\n{prompt}")
|
||||
|
||||
check_finds([PROMPT_MESSAGE], (QUESTION,))
|
||||
check_finds([ASSIST_MESSAGE], (ANSWER,))
|
||||
|
@ -187,7 +198,7 @@ if __name__ == "__main__":
|
|||
tools=[],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
for style in ToolsPromptStyle:
|
||||
if (style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2) != (model_name.startswith("functionary")):
|
||||
continue
|
||||
|
@ -202,9 +213,9 @@ if __name__ == "__main__":
|
|||
|
||||
for tn, args in argss.items():
|
||||
ch = get_chat_handler(args, parallel_calls=True, tool_style=style)
|
||||
|
||||
|
||||
print(f'\n### {tn}\n')
|
||||
|
||||
|
||||
print(f'\nPrompt:\n\n```json\n{ch.output_format_prompt.content}\n```\n')
|
||||
|
||||
print(f'\nGrammar:\n\n```js\n{ch.grammar}\n```\n')
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue