openai: refactor chat handler vs. template

parent 3c3eff52aa
commit 6935503b53

4 changed files with 1133 additions and 567 deletions
@@ -88,12 +88,9 @@ class ChatTemplate(BaseModel):
     def probe_template_capabilities(self):

-        def test(messages: list[Message]):
-            return self._template.render(messages=messages, eos_token=self.eos_token, bos_token=self.bos_token, raise_exception=raise_exception, add_generation_prompt=True)
-
         def succeeds(messages: list[Message], strings_to_find = ()):
             try:
-                result = test(messages)
+                result = self.raw_render(messages, add_generation_prompt=True)
                 # print(result)
                 for s in strings_to_find:
                     if s not in result:
@@ -133,8 +130,8 @@ class ChatTemplate(BaseModel):

         delimiter = '<%$[SAMPLE]$%>'
         user_msg = Message(role="user", content="Hey")
-        empty_prompt = self.render([user_msg], add_generation_prompt=True).strip()
-        planted_prompt = self.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
+        empty_prompt = self.raw_render([user_msg], add_generation_prompt=True).strip()
+        planted_prompt = self.raw_render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
         assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
         [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
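A minimal sketch, not part of the commit, of the delimiter probe these two lines feed: render once without and once with a planted assistant message, then split on the delimiter to recover the template's assistant prefix/suffix. The helper name and the `tmpl`/`Message` parameters are assumptions; only `raw_render` and the delimiter trick come from the diff.

```python
# Hypothetical helper illustrating the probe; `tmpl` is assumed to expose the
# raw_render() method shown above, and `Message` the pydantic message model.
def probe_assistant_markers(tmpl, Message):
    delimiter = '<%$[SAMPLE]$%>'
    user_msg = Message(role="user", content="Hey")
    empty_prompt = tmpl.raw_render([user_msg], add_generation_prompt=True).strip()
    planted_prompt = tmpl.raw_render(
        [user_msg, Message(role="assistant", content=delimiter)],
        add_generation_prompt=False).strip()
    assert planted_prompt.startswith(empty_prompt)
    # Whatever surrounds the planted delimiter is the per-turn assistant prefix/suffix.
    prefix, suffix = planted_prompt[len(empty_prompt):].split(delimiter)
    return prefix, suffix
```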
@@ -181,10 +178,59 @@ class ChatTemplate(BaseModel):
             bos_token = tokenizer.bos_token,
             eos_token = tokenizer.eos_token)

-    def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
+    def raw_render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
+        result = self._template.render(
+            messages=messages,
+            eos_token=self.eos_token,
+            bos_token='' if omit_bos else self.bos_token,
+            raise_exception=raise_exception,
+            add_generation_prompt=add_generation_prompt,
+        )
+        return result
+
+class ChatHandlerArgs(BaseModel):
+    chat_template: ChatTemplate
+    response_schema: Optional[dict] = None
+    tools: Optional[list[Tool]] = None
+
+class ChatHandler(ABC):
+    def __init__(self, args: ChatHandlerArgs, style: Optional[ToolsPromptStyle]):
+        self.args = args
+        self.style = style
+        self.output_format_prompt: Optional[Message] = None
+        self.grammar: Optional[str] = None
+
+    @abstractmethod
+    def parse(self, s: str) -> Optional[Message]:
+        raise NotImplementedError()
+
+    def render_prompt(self, messages: list[Message]) -> str:
         def normalize(m: Message):
+            if self.style == ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS:
                 if m.tool_calls:
-                if not self.formats_tool_call or not self.formats_tool_call_content:
+                    m = Message(
+                        role=m.role,
+                        content=json.dumps({
+                            _THOUGHT_KEY: m.content or '',
+                            "next_step": {
+                                "tool_calls": [tc.model_dump() for tc in m.tool_calls]
+                            }
+                        }, indent=2)
+                    )
+                else:
+                    m = Message(
+                        role=m.role,
+                        content=json.dumps({
+                            _THOUGHT_KEY: '',
+                            "next_step": {
+                                "result": m.content
+                            }
+                        }, indent=2)
+                    )
+                # Fall through to benefit from role normalization
+
+            if m.tool_calls:
+                if not self.args.chat_template.formats_tool_call or not self.args.chat_template.formats_tool_call_content:
                     return Message(
                         role=m.role,
                         content='\n'.join([
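To make the new normalize() branch concrete, here is a sketch (illustrative only, not code from the commit) of the JSON an assistant tool-call message gets rewritten into under TOOLS_THOUGHTFUL_STEPS. The `_THOUGHT_KEY` value and the tool-call dict fields are placeholders; only the `_THOUGHT_KEY` / `next_step.tool_calls` structure is taken from the diff.

```python
import json

_THOUGHT_KEY = "thought"  # placeholder; the real constant is defined elsewhere in this module

example = json.dumps({
    _THOUGHT_KEY: "I should add the two numbers.",  # m.content or ''
    "next_step": {
        # One dict per tool call, as produced by tc.model_dump(); the fields
        # below are illustrative, not the project's exact ToolCall schema.
        "tool_calls": [
            {"function": {"name": "add", "arguments": {"a": 2, "b": 3}}}
        ]
    },
}, indent=2)

print(example)
```

Messages without tool calls take the `else` branch and get a `"result"` entry under `next_step` instead.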
@@ -195,7 +241,7 @@ class ChatTemplate(BaseModel):
                         ])
                     ])
                 )
-            elif self.expects_stringified_function_arguments:
+            elif self.args.chat_template.expects_stringified_function_arguments:
                 return Message(
                     role=m.role,
                     content=m.content,
@@ -215,7 +261,7 @@ class ChatTemplate(BaseModel):
                    )
                else:
                    return m
-            elif self.expects_strict_user_assistant_alternance and m.role not in ('user', 'assistant'):
+            elif self.args.chat_template.expects_strict_user_assistant_alternance and m.role not in ('user', 'assistant'):
                if m.role == "system":
                    return Message(role="user", content=f'[SYS]{m.content}[/SYS]')
                elif m.role == "tool":
@@ -228,7 +274,7 @@ class ChatTemplate(BaseModel):

         messages=[normalize(m) for m in messages]

-        if self.expects_strict_user_assistant_alternance:
+        if self.args.chat_template.expects_strict_user_assistant_alternance:
             new_messages=[]
             current_role = 'user'
             current_content = []
@@ -237,7 +283,7 @@ class ChatTemplate(BaseModel):
                 nonlocal current_content
                 nonlocal current_role

-                if self.expects_strict_user_assistant_alternance or current_content:
+                if self.args.chat_template.expects_strict_user_assistant_alternance or current_content:
                     new_messages.append(Message(
                         role=current_role,
                         content='\n'.join(current_content)
@@ -263,7 +309,7 @@ class ChatTemplate(BaseModel):
         messages = [m.model_dump() for m in messages]

         # if self.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
-        if self.expects_stringified_function_arguments:
+        if self.args.chat_template.expects_stringified_function_arguments:
             messages = [
                 {
                     **m,
@@ -281,33 +327,14 @@ class ChatTemplate(BaseModel):
                 for m in messages
             ]

-        result = self._template.render(
+        return self.args.chat_template.raw_render(
             messages=messages,
-            eos_token=self.eos_token,
-            bos_token='' if omit_bos else self.bos_token,
-            raise_exception=raise_exception,
-            add_generation_prompt=add_generation_prompt,
+            add_generation_prompt=True,
         )
-        return result
-
-
-class ChatHandlerArgs(BaseModel):
-    chat_template: ChatTemplate
-    response_schema: Optional[dict] = None
-    tools: Optional[list[Tool]] = None
-
-class ChatHandler(ABC):
-    def __init__(self, args: ChatHandlerArgs):
-        self.args = args
-        self.output_format_prompt: Optional[Message] = None
-        self.grammar: Optional[str] = None
-
-    @abstractmethod
-    def parse(self, s: str) -> Optional[Message]:
-        raise NotImplementedError()
-
 class NoToolsChatHandler(ChatHandler):
     def __init__(self, args: ChatHandlerArgs):
-        super().__init__(args)
+        super().__init__(args, None)
         assert not args.tools

         if args.response_schema:
@@ -327,8 +354,8 @@ class NoToolsChatHandler(ChatHandler):
         return Message(role="assistant", content=s)

 class ToolCallTagsChatHandler(ChatHandler):
-    def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, parallel_calls: bool):
-        super().__init__(args)
+    def __init__(self, args: ChatHandlerArgs, style: Optional[ToolsPromptStyle], escapes_underscores: bool, parallel_calls: bool):
+        super().__init__(args, style)

         converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
         tool_rules = []
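For reference, a sketch of what a minimal concrete handler looks like under the new `(args, style)` constructor convention. The class name is hypothetical and the `parse` body simply mirrors the no-tools behaviour shown earlier in this diff; it assumes the `ChatHandler`, `ChatHandlerArgs` and `Message` definitions introduced by this commit are importable.

```python
class PlainTextChatHandler(ChatHandler):
    def __init__(self, args: ChatHandlerArgs):
        super().__init__(args, None)  # no tools-prompt style

    def parse(self, s: str) -> Optional[Message]:
        # Hand the raw completion back as a plain assistant message.
        return Message(role="assistant", content=s)
```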
@@ -404,8 +431,8 @@ class ToolCallTagsChatHandler(ChatHandler):


 class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
-    def __init__(self, args: ChatHandlerArgs, template: str, parallel_calls: bool, escapes_underscores: bool = False):
-        super().__init__(args, escapes_underscores=escapes_underscores, parallel_calls=parallel_calls)
+    def __init__(self, args: ChatHandlerArgs, template: str, parallel_calls: bool, escapes_underscores: bool = False, style: Optional[ToolsPromptStyle] = None):
+        super().__init__(args, style=style, escapes_underscores=escapes_underscores, parallel_calls=parallel_calls)
         assert '{tools}' in template, 'Template must contain "{tools}"'

         self.output_format_prompt = Message(
@@ -418,7 +445,7 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):

 class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
     def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
-        super().__init__(args, escapes_underscores=False, parallel_calls=parallel_calls)
+        super().__init__(args, style=ToolsPromptStyle.TOOLS_HERMES_2_PRO, escapes_underscores=False, parallel_calls=parallel_calls)

         # Hackily import https://github.com/NousResearch/Hermes-Function-Calling
         path = str(Path(__file__).parent / "hermes_function_calling")
@@ -434,7 +461,7 @@ class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):

 class FunctionaryToolsChatHandler(ChatHandler):
     def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
-        super().__init__(args)
+        super().__init__(args, ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2)

         self.output_format_prompt = Message(
             role="system",
@@ -541,9 +568,9 @@ def _make_bespoke_schema(response_schema, tool_call_schema, parallel_calls):
         # "required": ["next_step"]
     }

-class BespokeToolsChatHandler(ChatHandler):
+class ThoughtfulStepsToolsChatHandler(ChatHandler):
     def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
-        super().__init__(args)
+        super().__init__(args, ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS)

         # args.response_schema = args.response_schema or {}
         converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
@@ -660,7 +687,7 @@ def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool, tool_style: Op
         return NoToolsChatHandler(args)

     elif tool_style == ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS:
-        return BespokeToolsChatHandler(args, parallel_calls=parallel_calls)
+        return ThoughtfulStepsToolsChatHandler(args, parallel_calls=parallel_calls)

     elif tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
         return FunctionaryToolsChatHandler(args, parallel_calls=parallel_calls)
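A sketch, under stated assumptions, of how the pieces fit together after this refactor: the dispatcher picks a handler for a tool-prompt style, the handler renders the prompt (instead of the template being rendered directly), and the same handler parses the model output. The variables `chat_template`, `tools`, `messages` and `raw_completion` are assumed to exist; the call signatures come from the diff above.

```python
# Illustrative wiring only; not taken verbatim from the commit.
args = ChatHandlerArgs(chat_template=chat_template, tools=tools)
handler = get_chat_handler(args, parallel_calls=True,
                           tool_style=ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS)

prompt = handler.render_prompt(messages)   # replaces chat_template.render(...)
# handler.grammar (if set) can constrain generation; handler.output_format_prompt
# is prepended as a system message, as main() does in the next hunk.
message = handler.parse(raw_completion)    # handler-specific output parsing
```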
@@ -140,7 +140,7 @@ def main(
         if chat_handler.output_format_prompt:
             messages = chat_template.add_system_prompt(messages, chat_handler.output_format_prompt)

-        prompt = chat_template.render(messages, add_generation_prompt=True)
+        prompt = chat_handler.render_prompt(messages)

         if verbose:
             sys.stderr.write(f'\n# REQUEST:\n\n{chat_request.model_dump_json(indent=2)}\n\n')

File diff suppressed because it is too large.
@@ -165,27 +165,6 @@ if __name__ == "__main__":
         check(chat_template.potentially_supports_parallel_calls == (model_name in MODELS_WITH_PARALLEL_CALLS),
               f"{model_name} should {'not ' if model_name not in MODELS_WITH_PARALLEL_CALLS else ''} be detected as potentially supporting parallel calls")

-        # if model_name == 'hermes_2_pro_mistral':
-        #     print("Skipping hermes_2_pro_mistral")
-        #     continue
-        def check_finds(msgs, strings_to_find):
-            prompt = chat_template.render(msgs, add_generation_prompt=True)
-            for s in strings_to_find:
-                check(str(s) in prompt, f"Missing {s} in prompt for {model_name}:\n{prompt}")
-
-        check_finds([PROMPT_MESSAGE], (QUESTION,))
-        check_finds([ASSIST_MESSAGE], (ANSWER,))
-        check_finds([TOOL_CALL_MESSAGE], (TEST_ARG_A, TEST_ARG_B, TOOL_NAME))
-        check_finds([THOUGHTFUL_TOOL_CALL_MESSAGE], (TEST_THOUGHT, TEST_ARG_A, TEST_ARG_B, TOOL_NAME,))
-        check_finds([TOOL_MESSAGE], (TEST_SUM,))
-        if chat_template.potentially_supports_parallel_calls:
-            check_finds([TOOL_MESSAGE], (TOOL_NAME,))
-
-        print(f"\n# {model_name}\n")
-        print(f'\nTemplate:\n\n```js\n{chat_template.template}\n```\n')
-
-        print(f'\nPrompt:\n\n```js\n{chat_template.render(TEST_MESSAGES, add_generation_prompt=True)}\n```\n')
-
         argss = {
             "with tools": ChatHandlerArgs(
                 chat_template=chat_template, #ChatTemplate.from_gguf(GGUFKeyValues(model)),
@@ -199,6 +178,13 @@ if __name__ == "__main__":
             ),
         }

+        print(f"\n# {model_name}\n")
+
+        if chat_template.potentially_supports_parallel_calls:
+            print("\n**Might Support Parallel Tool Calls**\n")
+
+        print(f'\nTemplate:\n\n```js\n{chat_template.template}\n```\n')
+
         for style in ToolsPromptStyle:
             if (style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2) != (model_name.startswith("functionary")):
                 continue
@@ -209,17 +195,39 @@ if __name__ == "__main__":
             if model_name == "mistral_instruct_v0_1" and style not in (ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS, ToolsPromptStyle.TOOLS_MIXTRAL):
                 continue

-            print(f'\n## {style}\n')
+            print(f'\n## {model_name} / {style.name}\n')

-            for tn, args in argss.items():
+            for tool_situation, args in argss.items():
                 ch = get_chat_handler(args, parallel_calls=True, tool_style=style)

-                print(f'\n### {tn}\n')
+                print(f'\n### {model_name} / {style.name} / {tool_situation}\n')
+
+                print(f'\nPrompt:\n\n```js\n{ch.render_prompt(TEST_MESSAGES)}\n```\n')

                 print(f'\nPrompt:\n\n```json\n{ch.output_format_prompt.content}\n```\n')

                 print(f'\nGrammar:\n\n```js\n{ch.grammar}\n```\n')

+                # if model_name == 'hermes_2_pro_mistral':
+                #     print("Skipping hermes_2_pro_mistral")
+                #     continue
+                def check_finds(msgs, strings_to_find):
+                    prompt = ch.render_prompt(msgs)
+                    for s in strings_to_find:
+                        check(str(s) in prompt, f"Missing {s} in prompt for {model_name}:\n{prompt}")
+
+                check_finds([PROMPT_MESSAGE], (QUESTION,))
+                check_finds([ASSIST_MESSAGE], (ANSWER,))
+                check_finds([TOOL_CALL_MESSAGE], (TEST_ARG_A, TEST_ARG_B, TOOL_NAME))
+                check_finds([THOUGHTFUL_TOOL_CALL_MESSAGE], (TEST_THOUGHT, TEST_ARG_A, TEST_ARG_B, TOOL_NAME,))
+                check_finds([TOOL_MESSAGE], (TEST_SUM,))
+                if chat_template.potentially_supports_parallel_calls:
+                    check_finds([TOOL_MESSAGE], (TOOL_NAME,))

     if failures:
         for f in failures:
             print(f'{f}\n\n')
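The tests above rely on a small `check()` helper that records failures rather than raising, as implied by the `failures` loop at the end of the hunk. Its actual implementation is not shown in this diff; the following is an assumed sketch consistent with that usage.

```python
# Assumed shape of the helper used by these tests; the real code may differ.
failures: list[str] = []

def check(condition: bool, message: str):
    if not condition:
        failures.append(message)
```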