openai: refactor chat handler vs. template
parent 3c3eff52aa
commit 6935503b53

4 changed files with 1133 additions and 567 deletions
@@ -88,12 +88,9 @@ class ChatTemplate(BaseModel):
     def probe_template_capabilities(self):
 
-        def test(messages: list[Message]):
-            return self._template.render(messages=messages, eos_token=self.eos_token, bos_token=self.bos_token, raise_exception=raise_exception, add_generation_prompt=True)
-
         def succeeds(messages: list[Message], strings_to_find = ()):
             try:
-                result = test(messages)
+                result = self.raw_render(messages, add_generation_prompt=True)
                 # print(result)
                 for s in strings_to_find:
                     if s not in result:
@@ -133,8 +130,8 @@ class ChatTemplate(BaseModel):
         delimiter = '<%$[SAMPLE]$%>'
         user_msg = Message(role="user", content="Hey")
-        empty_prompt = self.render([user_msg], add_generation_prompt=True).strip()
-        planted_prompt = self.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
+        empty_prompt = self.raw_render([user_msg], add_generation_prompt=True).strip()
+        planted_prompt = self.raw_render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
         assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
         [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
@@ -181,10 +178,59 @@ class ChatTemplate(BaseModel):
             bos_token = tokenizer.bos_token,
             eos_token = tokenizer.eos_token)
 
-    def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
+    def raw_render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
+        result = self._template.render(
+            messages=messages,
+            eos_token=self.eos_token,
+            bos_token='' if omit_bos else self.bos_token,
+            raise_exception=raise_exception,
+            add_generation_prompt=add_generation_prompt,
+        )
+        return result
+
+class ChatHandlerArgs(BaseModel):
+    chat_template: ChatTemplate
+    response_schema: Optional[dict] = None
+    tools: Optional[list[Tool]] = None
+
+class ChatHandler(ABC):
+    def __init__(self, args: ChatHandlerArgs, style: Optional[ToolsPromptStyle]):
+        self.args = args
+        self.style = style
+        self.output_format_prompt: Optional[Message] = None
+        self.grammar: Optional[str] = None
+
+    @abstractmethod
+    def parse(self, s: str) -> Optional[Message]:
+        raise NotImplementedError()
+
+    def render_prompt(self, messages: list[Message]) -> str:
         def normalize(m: Message):
+            if self.style == ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS:
+                if m.tool_calls:
+                    m = Message(
+                        role=m.role,
+                        content=json.dumps({
+                            _THOUGHT_KEY: m.content or '',
+                            "next_step": {
+                                "tool_calls": [tc.model_dump() for tc in m.tool_calls]
+                            }
+                        }, indent=2)
+                    )
+                else:
+                    m = Message(
+                        role=m.role,
+                        content=json.dumps({
+                            _THOUGHT_KEY: '',
+                            "next_step": {
+                                "result": m.content
+                            }
+                        }, indent=2)
+                    )
+                # Fall through to benefit from role normalization
+
             if m.tool_calls:
-                if not self.formats_tool_call or not self.formats_tool_call_content:
+                if not self.args.chat_template.formats_tool_call or not self.args.chat_template.formats_tool_call_content:
                     return Message(
                         role=m.role,
                         content='\n'.join([
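To illustrate the `TOOLS_THOUGHTFUL_STEPS` branch of `normalize()` above: it re-serializes an assistant message as a JSON "thought plus next step" payload. Below is a minimal runnable sketch of that payload. The value of `_THOUGHT_KEY` and the exact fields produced by `ToolCall.model_dump()` are defined outside this diff, so the `"thought"` key and the tool-call shape are assumptions for illustration only.

```python
import json

# Assumed placeholder: the real _THOUGHT_KEY constant is defined elsewhere in this file.
_THOUGHT_KEY = "thought"

# Hypothetical assistant tool-call message, re-serialized the way normalize()
# does for TOOLS_THOUGHTFUL_STEPS: free-text content becomes the thought, and
# the tool calls move under "next_step".
content = json.dumps({
    _THOUGHT_KEY: "I should call the tool with these arguments.",
    "next_step": {
        "tool_calls": [
            # Illustrative shape; the real fields come from ToolCall.model_dump().
            {"id": "call_0", "type": "function",
             "function": {"name": "add", "arguments": {"a": 1, "b": 2}}},
        ]
    },
}, indent=2)
print(content)
```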
@@ -195,7 +241,7 @@ class ChatTemplate(BaseModel):
                             ])
                         ])
                     )
-                elif self.expects_stringified_function_arguments:
+                elif self.args.chat_template.expects_stringified_function_arguments:
                     return Message(
                         role=m.role,
                         content=m.content,
@@ -215,7 +261,7 @@ class ChatTemplate(BaseModel):
                     )
                 else:
                     return m
-            elif self.expects_strict_user_assistant_alternance and m.role not in ('user', 'assistant'):
+            elif self.args.chat_template.expects_strict_user_assistant_alternance and m.role not in ('user', 'assistant'):
                 if m.role == "system":
                     return Message(role="user", content=f'[SYS]{m.content}[/SYS]')
                 elif m.role == "tool":
@@ -228,7 +274,7 @@ class ChatTemplate(BaseModel):
 
         messages=[normalize(m) for m in messages]
 
-        if self.expects_strict_user_assistant_alternance:
+        if self.args.chat_template.expects_strict_user_assistant_alternance:
             new_messages=[]
             current_role = 'user'
             current_content = []
@@ -237,7 +283,7 @@ class ChatTemplate(BaseModel):
             nonlocal current_content
             nonlocal current_role
 
-            if self.expects_strict_user_assistant_alternance or current_content:
+            if self.args.chat_template.expects_strict_user_assistant_alternance or current_content:
                 new_messages.append(Message(
                     role=current_role,
                     content='\n'.join(current_content)
@@ -263,7 +309,7 @@ class ChatTemplate(BaseModel):
         messages = [m.model_dump() for m in messages]
 
         # if self.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
-        if self.expects_stringified_function_arguments:
+        if self.args.chat_template.expects_stringified_function_arguments:
             messages = [
                 {
                     **m,
@@ -281,33 +327,14 @@ class ChatTemplate(BaseModel):
                 for m in messages
             ]
 
-        result = self._template.render(
+        return self.args.chat_template.raw_render(
             messages=messages,
-            eos_token=self.eos_token,
-            bos_token='' if omit_bos else self.bos_token,
-            raise_exception=raise_exception,
-            add_generation_prompt=add_generation_prompt,
+            add_generation_prompt=True,
         )
-        return result
-
-class ChatHandlerArgs(BaseModel):
-    chat_template: ChatTemplate
-    response_schema: Optional[dict] = None
-    tools: Optional[list[Tool]] = None
-
-class ChatHandler(ABC):
-    def __init__(self, args: ChatHandlerArgs):
-        self.args = args
-        self.output_format_prompt: Optional[Message] = None
-        self.grammar: Optional[str] = None
-
-    @abstractmethod
-    def parse(self, s: str) -> Optional[Message]:
-        raise NotImplementedError()
 
 class NoToolsChatHandler(ChatHandler):
     def __init__(self, args: ChatHandlerArgs):
-        super().__init__(args)
+        super().__init__(args, None)
         assert not args.tools
 
         if args.response_schema:
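The hunk above moves `ChatHandlerArgs`/`ChatHandler` earlier in the file and widens the base constructor to take an optional `ToolsPromptStyle`; tool-free handlers like `NoToolsChatHandler` now pass `None`. A self-contained sketch of that constructor contract, with stand-in types where this diff does not show the real ones:

```python
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional

# Stand-in: the real enum lives in this repo and has more members.
class ToolsPromptStyle(Enum):
    TOOLS_THOUGHTFUL_STEPS = "thoughtful_steps"
    TYPESCRIPT_FUNCTIONARY_V2 = "functionary_v2"

class ChatHandler(ABC):
    # Mirrors the new signature: every subclass now declares its prompt style.
    def __init__(self, args, style: Optional[ToolsPromptStyle]):
        self.args = args
        self.style = style
        self.output_format_prompt = None
        self.grammar = None

    @abstractmethod
    def parse(self, s: str):
        raise NotImplementedError()

# Hypothetical tool-free subclass: like NoToolsChatHandler after this change,
# it passes style=None to the base constructor.
class PlainChatHandler(ChatHandler):
    def __init__(self, args):
        super().__init__(args, None)

    def parse(self, s: str):
        return s

h = PlainChatHandler(args=None)
assert h.style is None
```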
@@ -327,8 +354,8 @@ class NoToolsChatHandler(ChatHandler):
         return Message(role="assistant", content=s)
 
 class ToolCallTagsChatHandler(ChatHandler):
-    def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, parallel_calls: bool):
-        super().__init__(args)
+    def __init__(self, args: ChatHandlerArgs, style: Optional[ToolsPromptStyle], escapes_underscores: bool, parallel_calls: bool):
+        super().__init__(args, style)
 
         converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
         tool_rules = []
@@ -404,8 +431,8 @@ class ToolCallTagsChatHandler(ChatHandler):
 
 
 class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
-    def __init__(self, args: ChatHandlerArgs, template: str, parallel_calls: bool, escapes_underscores: bool = False):
-        super().__init__(args, escapes_underscores=escapes_underscores, parallel_calls=parallel_calls)
+    def __init__(self, args: ChatHandlerArgs, template: str, parallel_calls: bool, escapes_underscores: bool = False, style: Optional[ToolsPromptStyle] = None):
+        super().__init__(args, style=style, escapes_underscores=escapes_underscores, parallel_calls=parallel_calls)
         assert '{tools}' in template, 'Template must contain "{tools}"'
 
         self.output_format_prompt = Message(
@@ -418,7 +445,7 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
 
 class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
     def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
-        super().__init__(args, escapes_underscores=False, parallel_calls=parallel_calls)
+        super().__init__(args, style=ToolsPromptStyle.TOOLS_HERMES_2_PRO, escapes_underscores=False, parallel_calls=parallel_calls)
 
         # Hackily import https://github.com/NousResearch/Hermes-Function-Calling
         path = str(Path(__file__).parent / "hermes_function_calling")
@@ -434,7 +461,7 @@ class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
 
 class FunctionaryToolsChatHandler(ChatHandler):
     def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
-        super().__init__(args)
+        super().__init__(args, ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2)
 
         self.output_format_prompt = Message(
             role="system",
@@ -541,9 +568,9 @@ def _make_bespoke_schema(response_schema, tool_call_schema, parallel_calls):
         # "required": ["next_step"]
     }
 
-class BespokeToolsChatHandler(ChatHandler):
+class ThoughtfulStepsToolsChatHandler(ChatHandler):
     def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
-        super().__init__(args)
+        super().__init__(args, ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS)
 
         # args.response_schema = args.response_schema or {}
         converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
@@ -660,7 +687,7 @@ def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool, tool_style: Op
         return NoToolsChatHandler(args)
 
     elif tool_style == ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS:
-        return BespokeToolsChatHandler(args, parallel_calls=parallel_calls)
+        return ThoughtfulStepsToolsChatHandler(args, parallel_calls=parallel_calls)
 
     elif tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
         return FunctionaryToolsChatHandler(args, parallel_calls=parallel_calls)
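To make the renamed dispatch concrete, here is a small runnable mirror of the `get_chat_handler()` branches shown above. The class and style names come from this diff; the enum values and the string-returning helper are placeholders for illustration.

```python
from enum import Enum
from typing import Optional

class ToolsPromptStyle(Enum):  # placeholder values; member names per this diff
    TOOLS_THOUGHTFUL_STEPS = "thoughtful_steps"
    TYPESCRIPT_FUNCTIONARY_V2 = "functionary_v2"

def pick_handler_name(has_tools: bool, tool_style: Optional[ToolsPromptStyle]) -> str:
    # Mirrors the branches of get_chat_handler() in the hunk above.
    if not has_tools:
        return "NoToolsChatHandler"
    if tool_style == ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS:
        # Renamed in this commit from BespokeToolsChatHandler.
        return "ThoughtfulStepsToolsChatHandler"
    if tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
        return "FunctionaryToolsChatHandler"
    raise ValueError(f"unhandled tool style: {tool_style}")

print(pick_handler_name(True, ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS))
```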
@@ -140,7 +140,7 @@ def main(
         if chat_handler.output_format_prompt:
             messages = chat_template.add_system_prompt(messages, chat_handler.output_format_prompt)
 
-        prompt = chat_template.render(messages, add_generation_prompt=True)
+        prompt = chat_handler.render_prompt(messages)
 
         if verbose:
             sys.stderr.write(f'\n# REQUEST:\n\n{chat_request.model_dump_json(indent=2)}\n\n')
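The server-side change above is the heart of the refactor: prompt construction moves off `ChatTemplate` and onto the handler, which normalizes messages and then delegates the final Jinja render to `ChatTemplate.raw_render`. A toy before/after sketch, with stand-in classes since the real ones are not importable in isolation:

```python
# Stand-ins for illustration only; the real objects are built in main().
class OldTemplate:
    def render(self, messages, add_generation_prompt):
        return "rendered directly by ChatTemplate.render"

class NewHandler:
    def render_prompt(self, messages):
        # Internally: normalize(messages), then ChatTemplate.raw_render(...).
        return "rendered via ChatHandler.render_prompt"

messages = []

# Before this commit:
prompt = OldTemplate().render(messages, add_generation_prompt=True)

# After this commit: the handler owns normalization and the render call.
prompt = NewHandler().render_prompt(messages)
print(prompt)
```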
File diff suppressed because it is too large
@@ -165,27 +165,6 @@ if __name__ == "__main__":
         check(chat_template.potentially_supports_parallel_calls == (model_name in MODELS_WITH_PARALLEL_CALLS),
               f"{model_name} should {'not ' if model_name not in MODELS_WITH_PARALLEL_CALLS else ''} be detected as potentially supporting parallel calls")
 
-        # if model_name == 'hermes_2_pro_mistral':
-        #     print("Skipping hermes_2_pro_mistral")
-        #     continue
-        def check_finds(msgs, strings_to_find):
-            prompt = chat_template.render(msgs, add_generation_prompt=True)
-            for s in strings_to_find:
-                check(str(s) in prompt, f"Missing {s} in prompt for {model_name}:\n{prompt}")
-
-        check_finds([PROMPT_MESSAGE], (QUESTION,))
-        check_finds([ASSIST_MESSAGE], (ANSWER,))
-        check_finds([TOOL_CALL_MESSAGE], (TEST_ARG_A, TEST_ARG_B, TOOL_NAME))
-        check_finds([THOUGHTFUL_TOOL_CALL_MESSAGE], (TEST_THOUGHT, TEST_ARG_A, TEST_ARG_B, TOOL_NAME,))
-        check_finds([TOOL_MESSAGE], (TEST_SUM,))
-        if chat_template.potentially_supports_parallel_calls:
-            check_finds([TOOL_MESSAGE], (TOOL_NAME,))
-
-        print(f"\n# {model_name}\n")
-        print(f'\nTemplate:\n\n```js\n{chat_template.template}\n```\n')
-
-        print(f'\nPrompt:\n\n```js\n{chat_template.render(TEST_MESSAGES, add_generation_prompt=True)}\n```\n')
-
         argss = {
             "with tools": ChatHandlerArgs(
                 chat_template=chat_template, #ChatTemplate.from_gguf(GGUFKeyValues(model)),
@@ -199,6 +178,13 @@ if __name__ == "__main__":
             ),
         }
 
+        print(f"\n# {model_name}\n")
+
+        if chat_template.potentially_supports_parallel_calls:
+            print("\n**Might Support Parallel Tool Calls**\n")
+
+        print(f'\nTemplate:\n\n```js\n{chat_template.template}\n```\n')
+
         for style in ToolsPromptStyle:
             if (style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2) != (model_name.startswith("functionary")):
                 continue
@@ -209,17 +195,39 @@ if __name__ == "__main__":
             if model_name == "mistral_instruct_v0_1" and style not in (ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS, ToolsPromptStyle.TOOLS_MIXTRAL):
                 continue
 
-            print(f'\n## {style}\n')
+            print(f'\n## {model_name} / {style.name}\n')
 
-            for tn, args in argss.items():
+            for tool_situation, args in argss.items():
                 ch = get_chat_handler(args, parallel_calls=True, tool_style=style)
 
-                print(f'\n### {tn}\n')
+                print(f'\n### {model_name} / {style.name} / {tool_situation}\n')
 
+                print(f'\nPrompt:\n\n```js\n{ch.render_prompt(TEST_MESSAGES)}\n```\n')
+
                 print(f'\nPrompt:\n\n```json\n{ch.output_format_prompt.content}\n```\n')
 
                 print(f'\nGrammar:\n\n```js\n{ch.grammar}\n```\n')
 
+                # if model_name == 'hermes_2_pro_mistral':
+                #     print("Skipping hermes_2_pro_mistral")
+                #     continue
+                def check_finds(msgs, strings_to_find):
+                    prompt = ch.render_prompt(msgs)
+                    for s in strings_to_find:
+                        check(str(s) in prompt, f"Missing {s} in prompt for {model_name}:\n{prompt}")
+
+                check_finds([PROMPT_MESSAGE], (QUESTION,))
+                check_finds([ASSIST_MESSAGE], (ANSWER,))
+                check_finds([TOOL_CALL_MESSAGE], (TEST_ARG_A, TEST_ARG_B, TOOL_NAME))
+                check_finds([THOUGHTFUL_TOOL_CALL_MESSAGE], (TEST_THOUGHT, TEST_ARG_A, TEST_ARG_B, TOOL_NAME,))
+                check_finds([TOOL_MESSAGE], (TEST_SUM,))
+                if chat_template.potentially_supports_parallel_calls:
+                    check_finds([TOOL_MESSAGE], (TOOL_NAME,))
 
 
     if failures:
         for f in failures:
             print(f'{f}\n\n')