openai: refactor chat handler vs. template

ochafik 2024-03-30 01:50:36 +00:00
parent 3c3eff52aa
commit 6935503b53
4 changed files with 1133 additions and 567 deletions


@@ -88,12 +88,9 @@ class ChatTemplate(BaseModel):
     def probe_template_capabilities(self):
-        def test(messages: list[Message]):
-            return self._template.render(messages=messages, eos_token=self.eos_token, bos_token=self.bos_token, raise_exception=raise_exception, add_generation_prompt=True)
         def succeeds(messages: list[Message], strings_to_find = ()):
             try:
-                result = test(messages)
+                result = self.raw_render(messages, add_generation_prompt=True)
                 # print(result)
                 for s in strings_to_find:
                     if s not in result:
@@ -133,8 +130,8 @@ class ChatTemplate(BaseModel):
         delimiter = '<%$[SAMPLE]$%>'
         user_msg = Message(role="user", content="Hey")
-        empty_prompt = self.render([user_msg], add_generation_prompt=True).strip()
-        planted_prompt = self.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
+        empty_prompt = self.raw_render([user_msg], add_generation_prompt=True).strip()
+        planted_prompt = self.raw_render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
         assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
         [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
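
The hunk above is a delimiter-planting probe: render the same conversation with and without a planted assistant message, then split on the sentinel to recover the template's assistant-turn prefix and suffix. A minimal standalone sketch of the idea, with an illustrative toy template (not the repo's exact code):

import jinja2

# Toy chat template standing in for a real model's Jinja chat template.
TEMPLATE = jinja2.Template(
    "{% for m in messages %}<|{{ m['role'] }}|>{{ m['content'] }}<|end|>\n{% endfor %}"
    "{% if add_generation_prompt %}<|assistant|>{% endif %}")

def raw_render(messages, add_generation_prompt):
    return TEMPLATE.render(messages=messages, add_generation_prompt=add_generation_prompt)

delimiter = '<%$[SAMPLE]$%>'
user_msg = {"role": "user", "content": "Hey"}
empty_prompt = raw_render([user_msg], True).strip()
planted_prompt = raw_render([user_msg, {"role": "assistant", "content": delimiter}], False).strip()
assert planted_prompt.startswith(empty_prompt)
# Whatever surrounds the delimiter is the template's assistant prefix/suffix.
prefix, suffix = planted_prompt[len(empty_prompt):].split(delimiter)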
@@ -181,10 +178,59 @@ class ChatTemplate(BaseModel):
             bos_token = tokenizer.bos_token,
             eos_token = tokenizer.eos_token)
 
-    def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
+    def raw_render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
+        result = self._template.render(
+            messages=messages,
+            eos_token=self.eos_token,
+            bos_token='' if omit_bos else self.bos_token,
+            raise_exception=raise_exception,
+            add_generation_prompt=add_generation_prompt,
+        )
+        return result
+
+class ChatHandlerArgs(BaseModel):
+    chat_template: ChatTemplate
+    response_schema: Optional[dict] = None
+    tools: Optional[list[Tool]] = None
+
+class ChatHandler(ABC):
+    def __init__(self, args: ChatHandlerArgs, style: Optional[ToolsPromptStyle]):
+        self.args = args
+        self.style = style
+        self.output_format_prompt: Optional[Message] = None
+        self.grammar: Optional[str] = None
+
+    @abstractmethod
+    def parse(self, s: str) -> Optional[Message]:
+        raise NotImplementedError()
+
+    def render_prompt(self, messages: list[Message]) -> str:
         def normalize(m: Message):
+            if self.style == ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS:
+                if m.tool_calls:
+                    m = Message(
+                        role=m.role,
+                        content=json.dumps({
+                            _THOUGHT_KEY: m.content or '',
+                            "next_step": {
+                                "tool_calls": [tc.model_dump() for tc in m.tool_calls]
+                            }
+                        }, indent=2)
+                    )
+                else:
+                    m = Message(
+                        role=m.role,
+                        content=json.dumps({
+                            _THOUGHT_KEY: '',
+                            "next_step": {
+                                "result": m.content
+                            }
+                        }, indent=2)
+                    )
+                # Fall through to benefit from role normalization
             if m.tool_calls:
-                if not self.formats_tool_call or not self.formats_tool_call_content:
+                if not self.args.chat_template.formats_tool_call or not self.args.chat_template.formats_tool_call_content:
                     return Message(
                         role=m.role,
                         content='\n'.join([
@@ -195,7 +241,7 @@ class ChatTemplate(BaseModel):
                             ])
                         ])
                     )
-                elif self.expects_stringified_function_arguments:
+                elif self.args.chat_template.expects_stringified_function_arguments:
                     return Message(
                         role=m.role,
                         content=m.content,
@@ -215,7 +261,7 @@ class ChatTemplate(BaseModel):
                    )
                else:
                    return m
-            elif self.expects_strict_user_assistant_alternance and m.role not in ('user', 'assistant'):
+            elif self.args.chat_template.expects_strict_user_assistant_alternance and m.role not in ('user', 'assistant'):
                if m.role == "system":
                    return Message(role="user", content=f'[SYS]{m.content}[/SYS]')
                elif m.role == "tool":
@@ -228,7 +274,7 @@ class ChatTemplate(BaseModel):
 
         messages=[normalize(m) for m in messages]
 
-        if self.expects_strict_user_assistant_alternance:
+        if self.args.chat_template.expects_strict_user_assistant_alternance:
             new_messages=[]
             current_role = 'user'
             current_content = []
@@ -237,7 +283,7 @@ class ChatTemplate(BaseModel):
                nonlocal current_content
                nonlocal current_role
 
-               if self.expects_strict_user_assistant_alternance or current_content:
+               if self.args.chat_template.expects_strict_user_assistant_alternance or current_content:
                    new_messages.append(Message(
                        role=current_role,
                        content='\n'.join(current_content)
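
These alternance hunks patch a flush-based coalescing loop: roles a strict user/assistant template cannot render are folded into user turns, and consecutive same-role turns are merged. A hedged, simplified sketch of that loop on plain dicts (illustrative only, not the repo's exact code):

def coalesce_alternance(messages: list[dict]) -> list[dict]:
    new_messages: list[dict] = []
    current_role = None
    current_content: list[str] = []

    def flush():
        if current_content:
            new_messages.append({"role": current_role, "content": "\n".join(current_content)})

    for m in messages:
        # System/tool turns get folded into user turns, as the [SYS]...[/SYS] hunk does.
        role = m["role"] if m["role"] in ("user", "assistant") else "user"
        content = f"[SYS]{m['content']}[/SYS]" if m["role"] == "system" else m["content"]
        if role != current_role:
            flush()
            current_content = []
            current_role = role
        current_content.append(content)
    flush()
    return new_messages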
@@ -263,7 +309,7 @@ class ChatTemplate(BaseModel):
         messages = [m.model_dump() for m in messages]
 
         # if self.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
-        if self.expects_stringified_function_arguments:
+        if self.args.chat_template.expects_stringified_function_arguments:
             messages = [
                 {
                     **m,
@@ -281,33 +327,14 @@ class ChatTemplate(BaseModel):
                for m in messages
            ]
 
-        result = self._template.render(
+        return self.args.chat_template.raw_render(
             messages=messages,
-            eos_token=self.eos_token,
-            bos_token='' if omit_bos else self.bos_token,
-            raise_exception=raise_exception,
-            add_generation_prompt=add_generation_prompt,
+            add_generation_prompt=True,
         )
-        return result
-
-class ChatHandlerArgs(BaseModel):
-    chat_template: ChatTemplate
-    response_schema: Optional[dict] = None
-    tools: Optional[list[Tool]] = None
-
-class ChatHandler(ABC):
-    def __init__(self, args: ChatHandlerArgs):
-        self.args = args
-        self.output_format_prompt: Optional[Message] = None
-        self.grammar: Optional[str] = None
-
-    @abstractmethod
-    def parse(self, s: str) -> Optional[Message]:
-        raise NotImplementedError()
 
 class NoToolsChatHandler(ChatHandler):
     def __init__(self, args: ChatHandlerArgs):
-        super().__init__(args)
+        super().__init__(args, None)
         assert not args.tools
 
         if args.response_schema:
@@ -327,8 +354,8 @@ class NoToolsChatHandler(ChatHandler):
         return Message(role="assistant", content=s)
 
 class ToolCallTagsChatHandler(ChatHandler):
-    def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, parallel_calls: bool):
-        super().__init__(args)
+    def __init__(self, args: ChatHandlerArgs, style: Optional[ToolsPromptStyle], escapes_underscores: bool, parallel_calls: bool):
+        super().__init__(args, style)
 
         converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
         tool_rules = []
@@ -404,8 +431,8 @@ class ToolCallTagsChatHandler(ChatHandler):
 
 class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
-    def __init__(self, args: ChatHandlerArgs, template: str, parallel_calls: bool, escapes_underscores: bool = False):
-        super().__init__(args, escapes_underscores=escapes_underscores, parallel_calls=parallel_calls)
+    def __init__(self, args: ChatHandlerArgs, template: str, parallel_calls: bool, escapes_underscores: bool = False, style: Optional[ToolsPromptStyle] = None):
+        super().__init__(args, style=style, escapes_underscores=escapes_underscores, parallel_calls=parallel_calls)
         assert '{tools}' in template, 'Template must contain "{tools}"'
 
         self.output_format_prompt = Message(
@@ -418,7 +445,7 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
 
 class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
     def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
-        super().__init__(args, escapes_underscores=False, parallel_calls=parallel_calls)
+        super().__init__(args, style=ToolsPromptStyle.TOOLS_HERMES_2_PRO, escapes_underscores=False, parallel_calls=parallel_calls)
 
         # Hackily import https://github.com/NousResearch/Hermes-Function-Calling
         path = str(Path(__file__).parent / "hermes_function_calling")
@@ -434,7 +461,7 @@ class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
 
 class FunctionaryToolsChatHandler(ChatHandler):
     def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
-        super().__init__(args)
+        super().__init__(args, ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2)
 
         self.output_format_prompt = Message(
             role="system",
@@ -541,9 +568,9 @@ def _make_bespoke_schema(response_schema, tool_call_schema, parallel_calls):
                 # "required": ["next_step"]
             }
 
-class BespokeToolsChatHandler(ChatHandler):
+class ThoughtfulStepsToolsChatHandler(ChatHandler):
     def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
-        super().__init__(args)
+        super().__init__(args, ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS)
 
         # args.response_schema = args.response_schema or {}
         converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
@@ -660,7 +687,7 @@ def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool, tool_style: Op
         return NoToolsChatHandler(args)
     elif tool_style == ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS:
-        return BespokeToolsChatHandler(args, parallel_calls=parallel_calls)
+        return ThoughtfulStepsToolsChatHandler(args, parallel_calls=parallel_calls)
     elif tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
         return FunctionaryToolsChatHandler(args, parallel_calls=parallel_calls)
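
Net effect of this file: prompt rendering moves from ChatTemplate.render into ChatHandler.render_prompt, which first normalizes messages for the handler's style. A hedged sketch of the TOOLS_THOUGHTFUL_STEPS normalization on plain dicts (the value of _THOUGHT_KEY is an assumption; the diff only shows the constant's name):

import json

_THOUGHT_KEY = "thought"  # assumed value, not shown in the diff

def normalize_thoughtful_steps(m: dict) -> dict:
    # Re-serialize the turn as a JSON envelope with a free-text thought plus a
    # "next_step" that either calls tools or returns a result, so any chat
    # template can render it as ordinary message content.
    if m.get("tool_calls"):
        payload = {_THOUGHT_KEY: m.get("content") or "",
                   "next_step": {"tool_calls": m["tool_calls"]}}
    else:
        payload = {_THOUGHT_KEY: "",
                   "next_step": {"result": m.get("content")}}
    return {"role": m["role"], "content": json.dumps(payload, indent=2)}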


@@ -140,7 +140,7 @@ def main(
         if chat_handler.output_format_prompt:
             messages = chat_template.add_system_prompt(messages, chat_handler.output_format_prompt)
 
-        prompt = chat_template.render(messages, add_generation_prompt=True)
+        prompt = chat_handler.render_prompt(messages)
 
         if verbose:
             sys.stderr.write(f'\n# REQUEST:\n\n{chat_request.model_dump_json(indent=2)}\n\n')
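
With this change the endpoint delegates prompting entirely to the handler. A hedged sketch of the resulting call flow (glue code and style choice assumed; names as in the diff):

args = ChatHandlerArgs(chat_template=chat_template, tools=chat_request.tools)
chat_handler = get_chat_handler(args, parallel_calls=True,
                                tool_style=ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS)  # style choice assumed

if chat_handler.output_format_prompt:
    messages = chat_template.add_system_prompt(messages, chat_handler.output_format_prompt)

prompt = chat_handler.render_prompt(messages)  # was: chat_template.render(messages, add_generation_prompt=True)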

File diff suppressed because it is too large


@@ -165,27 +165,6 @@ if __name__ == "__main__":
         check(chat_template.potentially_supports_parallel_calls == (model_name in MODELS_WITH_PARALLEL_CALLS),
               f"{model_name} should {'not ' if model_name not in MODELS_WITH_PARALLEL_CALLS else ''} be detected as potentially supporting parallel calls")
 
-        # if model_name == 'hermes_2_pro_mistral':
-        #     print("Skipping hermes_2_pro_mistral")
-        #     continue
-
-        def check_finds(msgs, strings_to_find):
-            prompt = chat_template.render(msgs, add_generation_prompt=True)
-            for s in strings_to_find:
-                check(str(s) in prompt, f"Missing {s} in prompt for {model_name}:\n{prompt}")
-
-        check_finds([PROMPT_MESSAGE], (QUESTION,))
-        check_finds([ASSIST_MESSAGE], (ANSWER,))
-        check_finds([TOOL_CALL_MESSAGE], (TEST_ARG_A, TEST_ARG_B, TOOL_NAME))
-        check_finds([THOUGHTFUL_TOOL_CALL_MESSAGE], (TEST_THOUGHT, TEST_ARG_A, TEST_ARG_B, TOOL_NAME,))
-        check_finds([TOOL_MESSAGE], (TEST_SUM,))
-        if chat_template.potentially_supports_parallel_calls:
-            check_finds([TOOL_MESSAGE], (TOOL_NAME,))
-
-        print(f"\n# {model_name}\n")
-        print(f'\nTemplate:\n\n```js\n{chat_template.template}\n```\n')
-        print(f'\nPrompt:\n\n```js\n{chat_template.render(TEST_MESSAGES, add_generation_prompt=True)}\n```\n')
-
         argss = {
             "with tools": ChatHandlerArgs(
                 chat_template=chat_template, #ChatTemplate.from_gguf(GGUFKeyValues(model)),
@@ -199,6 +178,13 @@ if __name__ == "__main__":
             ),
         }
 
+        print(f"\n# {model_name}\n")
+
+        if chat_template.potentially_supports_parallel_calls:
+            print("\n**Might Support Parallel Tool Calls**\n")
+
+        print(f'\nTemplate:\n\n```js\n{chat_template.template}\n```\n')
+
         for style in ToolsPromptStyle:
             if (style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2) != (model_name.startswith("functionary")):
                 continue
@@ -209,17 +195,39 @@ if __name__ == "__main__":
             if model_name == "mistral_instruct_v0_1" and style not in (ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS, ToolsPromptStyle.TOOLS_MIXTRAL):
                 continue
 
-            print(f'\n## {style}\n')
+            print(f'\n## {model_name} / {style.name}\n')
 
-            for tn, args in argss.items():
+            for tool_situation, args in argss.items():
                 ch = get_chat_handler(args, parallel_calls=True, tool_style=style)
 
-                print(f'\n### {tn}\n')
+                print(f'\n### {model_name} / {style.name} / {tool_situation}\n')
+
+                print(f'\nPrompt:\n\n```js\n{ch.render_prompt(TEST_MESSAGES)}\n```\n')
 
                 print(f'\nPrompt:\n\n```json\n{ch.output_format_prompt.content}\n```\n')
                 print(f'\nGrammar:\n\n```js\n{ch.grammar}\n```\n')
+
+                # if model_name == 'hermes_2_pro_mistral':
+                #     print("Skipping hermes_2_pro_mistral")
+                #     continue
+
+                def check_finds(msgs, strings_to_find):
+                    prompt = ch.render_prompt(msgs)
+                    for s in strings_to_find:
+                        check(str(s) in prompt, f"Missing {s} in prompt for {model_name}:\n{prompt}")
+
+                check_finds([PROMPT_MESSAGE], (QUESTION,))
+                check_finds([ASSIST_MESSAGE], (ANSWER,))
+                check_finds([TOOL_CALL_MESSAGE], (TEST_ARG_A, TEST_ARG_B, TOOL_NAME))
+                check_finds([THOUGHTFUL_TOOL_CALL_MESSAGE], (TEST_THOUGHT, TEST_ARG_A, TEST_ARG_B, TOOL_NAME,))
+                check_finds([TOOL_MESSAGE], (TEST_SUM,))
+                if chat_template.potentially_supports_parallel_calls:
+                    check_finds([TOOL_MESSAGE], (TOOL_NAME,))
 
     if failures:
         for f in failures:
             print(f'{f}\n\n')
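
The test harness's check helper is not shown in this diff; a plausible minimal definition consistent with the failures loop above (an assumption, not the repo's code):

failures: list[str] = []

def check(condition: bool, message: str):
    # Soft-fail: record the message so the whole model/style matrix still runs,
    # with all failures printed together at the end.
    if not condition:
        failures.append(message)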