Update test_chat_handlers.py
This commit is contained in:
parent
d8a53eadf2
commit
ad2f4c119a
1 changed files with 18 additions and 7 deletions
|
@ -3,6 +3,7 @@
|
||||||
# python -m examples.openai.test_chat_handlers | tee examples/openai/test_chat_handlers.md
|
# python -m examples.openai.test_chat_handlers | tee examples/openai/test_chat_handlers.md
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
from examples.openai.api import FunctionCall, Message, Tool, ToolCall, ToolFunction
|
from examples.openai.api import FunctionCall, Message, Tool, ToolCall, ToolFunction
|
||||||
from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, ToolsPromptStyle, get_chat_handler
|
from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, ToolsPromptStyle, get_chat_handler
|
||||||
|
@ -143,6 +144,7 @@ TEST_TEMPLATES = {
|
||||||
"bos_token": "<s>"
|
"bos_token": "<s>"
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
MODELS_WITH_PARALLEL_CALLS = set(["functionary_v2_2"])
|
||||||
TEST_TEMPLATES = {k: ChatTemplate(**v) for k, v in TEST_TEMPLATES.items()}
|
TEST_TEMPLATES = {k: ChatTemplate(**v) for k, v in TEST_TEMPLATES.items()}
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -151,7 +153,17 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
print(f'\nMessages:\n\n```js\n{json.dumps([m.model_dump() for m in TEST_MESSAGES], indent=2)}\n```\n')
|
print(f'\nMessages:\n\n```js\n{json.dumps([m.model_dump() for m in TEST_MESSAGES], indent=2)}\n```\n')
|
||||||
|
|
||||||
|
def check(b: bool, msg: str):
|
||||||
|
if not b:
|
||||||
|
sys.stderr.write(f'FAILURE: {msg}\n\n')
|
||||||
|
failures.append(msg)
|
||||||
|
|
||||||
|
functionary_v2_2 = TEST_TEMPLATES["functionary_v2_2"]
|
||||||
|
check(functionary_v2_2.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2, "functionary_v2_2 should be inferred as TYPESCRIPT_FUNCTIONARY_V2")
|
||||||
|
|
||||||
for model_name, chat_template in TEST_TEMPLATES.items():
|
for model_name, chat_template in TEST_TEMPLATES.items():
|
||||||
|
check(chat_template.potentially_supports_parallel_calls == (model_name in MODELS_WITH_PARALLEL_CALLS),
|
||||||
|
f"{model_name} should {'not ' if model_name not in MODELS_WITH_PARALLEL_CALLS else ''} be detected as potentially supporting parallel calls")
|
||||||
|
|
||||||
# if model_name == 'hermes_2_pro_mistral':
|
# if model_name == 'hermes_2_pro_mistral':
|
||||||
# print("Skipping hermes_2_pro_mistral")
|
# print("Skipping hermes_2_pro_mistral")
|
||||||
|
@ -159,8 +171,7 @@ if __name__ == "__main__":
|
||||||
def check_finds(msgs, strings_to_find):
|
def check_finds(msgs, strings_to_find):
|
||||||
prompt = chat_template.render(msgs, add_generation_prompt=True)
|
prompt = chat_template.render(msgs, add_generation_prompt=True)
|
||||||
for s in strings_to_find:
|
for s in strings_to_find:
|
||||||
if str(s) not in prompt:
|
check(str(s) in prompt, f"Missing {s} in prompt for {model_name}:\n{prompt}")
|
||||||
failures.append(f"Missing {s} in prompt for {model_name}:\n{prompt}")
|
|
||||||
|
|
||||||
check_finds([PROMPT_MESSAGE], (QUESTION,))
|
check_finds([PROMPT_MESSAGE], (QUESTION,))
|
||||||
check_finds([ASSIST_MESSAGE], (ANSWER,))
|
check_finds([ASSIST_MESSAGE], (ANSWER,))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue