server.py: refactor chat handlers

This commit is contained in:
ochafik 2024-03-29 02:47:33 +00:00
parent 5f3de16116
commit 59b411406f
3 changed files with 487 additions and 321 deletions

View file

@ -18,7 +18,7 @@ class Message(BaseModel):
class ToolFunction(BaseModel):
name: str
description: str
parameters: Any
parameters: dict[str, Any]
class Tool(BaseModel):
type: str

View file

@ -1,11 +1,14 @@
from abc import ABC, abstractmethod
from enum import Enum
from functools import wraps
import jinja2
import json
from pathlib import Path
import random
import re
import sys
from typing import Optional, Tuple, Callable
from typing import Any, Dict, Literal, Optional, Tuple, Callable, Union
from pydantic import BaseModel
from typeguard import typechecked
from examples.json_schema_to_grammar import SchemaConverter
@ -18,22 +21,52 @@ def raise_exception(msg: str):
raise Exception(msg)
@typechecked
class ChatFormat:
def __init__(self, template: str, eos_token: str, bos_token: str):
env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
self.template = env.from_string(template)
self.eos_token = eos_token
self.bos_token = bos_token
class ChatTemplate(BaseModel):
template: str
self.strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template
@property
def tool_style(self) -> 'ToolsPromptStyle':
return self._tool_style
def __init__(self, template: str, eos_token: str, bos_token: str):
super().__init__(template=template
)
env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
self._template = env.from_string(template)
self._eos_token = eos_token
self._bos_token = bos_token
self._strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template
if "<|recipient|>' + tool_call['function']['name']" in template:
self.tool_style = ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2
self._tool_style = ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2
else:
self.tool_style = ToolsPromptStyle.TOOLS_LONG
self._tool_style = ToolsPromptStyle.TOOLS_BESPOKE
# self._tool_style = ToolsPromptStyle.TOOLS_LONG
# TODO: Test whether the template supports formatting tool_calls
delimiter = '<%$[SAMPLE]$%>'
user_msg = Message(role="user", content="Hey")
empty_prompt = self.render([user_msg], add_generation_prompt=True).strip()
planted_prompt = self.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
[prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
sys.stderr.write(f"\n# prefix={prefix}\n# suffix={suffix}\n\n")
self._prefix = prefix
self._suffix = suffix
def strip_suffix(self, s: str) -> str:
if s.endswith(self._suffix):
return s[:-len(self._suffix)]
else:
sys.stderr.write(f"Expected suffix ({self._suffix}) not found: {s}\n")
return s
def __str__(self):
return f"ChatFormat(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})"
return f"ChatTemplate(template={self.template}, eos_token={self._eos_token}, bos_token={self._bos_token})"
def add_system_prompt(self, messages: list[Message], system_prompt: Message) -> list[Message]:
assert system_prompt.role == "system"
@ -48,13 +81,13 @@ class ChatFormat:
@staticmethod
def from_gguf(metadata: GGUFKeyValues):
tokens = metadata[Keys.Tokenizer.LIST]
return ChatFormat(
return ChatTemplate(
template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
bos_token = tokens[metadata[Keys.Tokenizer.BOS_ID]],
eos_token = tokens[metadata[Keys.Tokenizer.EOS_ID]])
def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
if self.strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
if self._strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
new_messages=[]
i = 0
n = len(messages)
@ -80,10 +113,10 @@ class ChatFormat:
messages = new_messages
# print(f'messages={messages}')
result = self.template.render(
result = self._template.render(
messages=messages,
eos_token=self.eos_token,
bos_token='' if omit_bos else self.bos_token,
eos_token=self._eos_token,
bos_token='' if omit_bos else self._bos_token,
raise_exception=raise_exception,
add_generation_prompt=add_generation_prompt,
)
@ -95,67 +128,176 @@ class ChatFormat:
# each model may need specific prompting (and/or constrained output,
# especially for models not fine-tuned for tool usage / function calling).
class ToolsPromptStyle(Enum):
# Short prompt w/ <tools>schemas</tools>
# Short prompt w/ <tools>schemas</tools>, <tool_call>...</tool_call> output
TOOLS_SHORT = 1
# Longer prompt w/ <tools>schemas</tools>
# Longer prompt w/ <tools>schemas</tools>, <tool_call>...</tool_call> output
TOOLS_LONG = 2
# Bespoke constrained output format that favours thought and reasoning
# while allowing unambiguous parsing of parallel tool calling.
TOOLS_BESPOKE = 3
# Large prompt for https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B
# <tool_call>...</tool_call> output
# Requires:
# - git clone https://github.com/NousResearch/Hermes-Function-Calling examples/openai/hermes_function_calling
# - Set large context length as their prompts are super long
TOOLS_HERMES_2_PRO = 3
TOOLS_HERMES_2_PRO = 4
# Seems to want to escape underscores in tool names and in the <tool\_call>...</tool\_call> tags
TOOLS_MISTRAL = 5
# Short prompt w/ TypeScript definitions for https://github.com/MeetKai/functionary
# https://github.com/MeetKai/functionary/blob/main/functionary/prompt_template/prompt_template_v2.py
# Note: see this prior attempt to support Functionary: https://github.com/ggerganov/llama.cpp/pull/5695
TYPESCRIPT_FUNCTIONARY_V2 = 4
TYPESCRIPT_FUNCTIONARY_V2 = 6
@typechecked
def make_tools_prompt(chat_format: ChatFormat, tools: list[Tool], indent=2) -> Message:
class ChatHandlerArgs(BaseModel):
chat_template: ChatTemplate
response_schema: Optional[dict] = None
tools: Optional[list[Tool]] = None
if chat_format.tool_style == ToolsPromptStyle.TOOLS_SHORT:
return Message(
class ChatHandler(ABC):
def __init__(self, args: ChatHandlerArgs):
self.args = args
self.output_format_prompt: Optional[Message] = None
self.grammar: Optional[str] = None
@abstractmethod
def parse(self, s: str) -> Optional[Message]:
raise NotImplementedError()
class NoToolsChatHandler(ChatHandler):
def __init__(self, args: ChatHandlerArgs):
super().__init__(args)
assert not args.tools
if args.response_schema:
self.output_format_prompt = Message(
role="system",
content=_please_respond_with_schema(args.response_schema)
)
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
self.grammar = converter.visit(args.response_schema, '')
else:
self.output_format_prompt = None
self.grammar = None
@typechecked
def parse(self, s: str) -> Optional[Message]:
return Message(role="assistant", content=s)
class ToolCallTagsChatHandler(ChatHandler):
def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, allow_parallel_calls: bool):
super().__init__(args)
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
tool_rules = [
converter.visit(
dict(
type="object",
properties=dict(
name=dict(type="string", pattern='^' + tool.function.name.replace('_', f'\\?_') + '$') if escapes_underscores \
else dict(const=tool.function.name),
arguments=tool.function.parameters,
),
required=['name', 'arguments']
),
f'{tool.function.name}-tool-call'
)
for tool in self.args.tools
]
def format_literal(s: str) -> str:
if escapes_underscores:
return ' "\\\\"? "_" '.join((converter._format_literal(part) for part in s.split('_')))
else:
return converter._format_literal(s)
tool_call_rule = converter._add_rule(
'tool_call',
format_literal("<tool_call>") + " space (" +
' | '.join(tool_rules) +
") space " + format_literal("</tool_call>"))# + ' space')
# Ideally we'd want a negative lookahead of /<tool\\?_call>/, but it's just too hard to express in GBNF for now.
# So we just over-constrain the content rule to not contain literals dangerously getting close to <tool_call>
content_rule = converter._add_rule('content', '[^<] | "<" [^t<] | "<t" [^o<]')
# content_rule = converter._add_rule('content', converter.not_literal('<tool_call>'))
converter._add_rule(
'root',
# tool_call_rule)
f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if allow_parallel_calls \
else f'{content_rule}* {tool_call_rule}?')
self.grammar = converter.format_grammar()
# # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
# # OR a tool-call message respecting the schema of any of the tools
# converter._add_rule(
# "root",
# converter._format_literal(prefix) + " (" +
# (response_rule or converter.not_literal("<tool_call>")) + " | " +
# converter._format_literal("<tool_call>") + " (" +
# ' | '.join(tool_rules) +
# ") " + converter._format_literal("</tool_call>") +
# ")") # + converter._format_literal(suffix))
@typechecked
def parse(self, s: str) -> Optional[Message]:
s = self.args.chat_template.strip_suffix(s)
if r'<tool\_call>' in s:
# Some weird escaping of underscores is happening w/ Mixtral 8x7B Instruct
s = s.replace(r'\_', '_')
parts = _tool_call_re.split(s)
if len(parts) == 1:
return Message(role="assistant", content=s)
else:
content = []
tool_calls = []
for i, part in enumerate(parts):
if i % 2 == 0:
content.append(part)
else:
try:
fc = json.loads(part)
except json.JSONDecodeError:
raise ValueError(f'Failed to parse tool call as JSON: {part}\nFull string: {s}')
tool_calls.append(
ToolCall(
id=gen_callid(),
function=FunctionCall(**fc)))
content = '\n'.join(content).strip()
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
# if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
# if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
# tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
# return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
# return None
# else:
# return Message(role="assistant", content=s)
class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
def __init__(self, args: ChatHandlerArgs, template: str, escapes_underscores=False, allow_parallel_calls=True):
super().__init__(args, escapes_underscores=escapes_underscores, allow_parallel_calls=allow_parallel_calls)
assert '{tools}' in template, 'Template must contain "{tools}"'
self.output_format_prompt = Message(
role="system",
content='\n'.join([
'Here are the tools available:',
'<tools>',
*(json.dumps(tool.model_dump(), indent=indent) for tool in tools),
'</tools>',
])
content=template.replace(
'{tools}',
'\n'.join(json.dumps(tool.model_dump(), indent=2) for tool in self.args.tools),
)
)
elif chat_format.tool_style == ToolsPromptStyle.TOOLS_LONG:
return Message(
role="system",
content='\n'.join([
# '''You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.''',
'''You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:''',
'''<tools>''',
_tools_typescript_signatures(tools),
# _tools_schema_signatures(tools, indent=indent),
'''</tools>''',
'',
# '''Use the following json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}''',
# '',
# '''For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:''',
'''To call each function, give its name and arguments within <tool_call></tool_call> XML tags as follows:''',
'''<tool_call>''',
'''{"name": <function-name>, "arguments": <args-dict>}''',
'''</tool_call>''',
# '''This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.''',
])
)
elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
return Message(
role="system",
content= '// Supported function definitions that should be called when necessary.\n' +
_tools_typescript_signatures(tools)
)
elif chat_format.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
def __init__(self, args: ChatHandlerArgs):
super().__init__(args, escapes_underscores=False, allow_parallel_calls=False)
# Hackily import https://github.com/NousResearch/Hermes-Function-Calling
path = str(Path(__file__).parent / "hermes_function_calling")
if path not in sys.path: sys.path.insert(0, path)
@ -166,16 +308,276 @@ def make_tools_prompt(chat_format: ChatFormat, tools: list[Tool], indent=2) -> M
prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in tools])
assert len(prompt) == 1 and prompt[0]["role"] == "system"
return Message(**prompt[0])
self.output_format_prompt = Message(**prompt[0])
class FunctionaryToolsChatHandler(ChatHandler):
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
super().__init__(args)
# Only allowing a single tool call at a time for now.
# Note that if there were more, they'd be separated by a '<|from|>assistant' literal
self.output_format_prompt = Message(
role="system",
content= '// Supported function definitions that should be called when necessary.\n' +
_tools_typescript_signatures(args.tools)
)
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
tool_rules = [
converter._add_rule(
tool.function.name + '-call',
converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
converter._format_literal('\n'))
# converter.visit(
# dict(
# type="object",
# properties=dict(
# name=dict(const=tool.function.name),
# arguments=tool.function.parameters,
# ),
# required=['name', 'arguments']
# ),
# f'{tool.function.name}-tool-call'
# )
for i, tool in enumerate(self.args.tools)
]
not_from_rule = converter._add_rule('not_from', converter.not_literal("<|from|>"))
content_without_start_rule = converter._add_rule(
'content_without_start',
converter._format_literal("all\n<|content|>") + ' ' + not_from_rule + '*')
start_rule = converter._add_rule('start', converter._format_literal('<|from|>assistant\n<|recipient|>'))
content_rule = converter._add_rule('content', start_rule + ' ' + content_without_start_rule)
tool_call_without_start_rule = converter._add_rule(
'tool_call_without_start',
' | '.join(tool_rules))
# + ' ' +
# converter.not_literal("all", dotall=False) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + not_from_rule + '*')
tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}')
# converter._add_rule('root', f'({content_without_start_rule} ({content_rule})* ({tool_call_rule}+ {content_rule}*)? | {tool_call_without_start_rule} (* {tool_call_rule}{content_rule}*')
converter._add_rule(
'root',
f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if allow_parallel_calls \
else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
self.grammar = converter.format_grammar()
# converter._add_rule(
# "root",
# converter._format_literal(prefix) + " (" +
# (response_rule or converter.not_literal("<|recipient|>")) + " | " +
# (' | '.join(
# converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
# converter.visit(tool.function.parameters, tool.function.name + '-args')
# for tool in tools
# )) +
# ") " +
# ")") # + converter._format_literal(suffix))
@typechecked
def parse(self, s: str) -> Optional[Message]:
s = self.args.chat_template.strip_suffix(s)
parts = _recipient_content_re.split(s)
if len(parts) == 1:
return Message(role="assistant", content=s)
else:
text_content = []
tool_calls: list[ToolCall] = []
for i in range((len(parts) - 1) // 3):
assert parts[i * 3].strip() == '', f'Unexpected content before tool call: {parts[i * 3]}'
recipient = parts[i * 3 + 1].strip()
content = parts[i * 3 + 2]
if recipient == 'all':
text_content.append(content)
else:
try:
arguments = json.loads(content)
except json.JSONDecodeError:
raise ValueError(f'Failed to parse tool call content as JSON: {content}')
tool_calls.append(
ToolCall(
id=gen_callid(),
function=FunctionCall(name=recipient, arguments=arguments)))
assert parts[-1].strip() in ('', '<|stop|>'), f'Unexpected content after tool calls: {parts[-1]}\nFull string: {s}'
content = '\n'.join(text_content).strip()
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
def _make_bespoke_schema(response_schema, tool_call_schema):
return {
"type": "object",
"properties": {
# "original_goal": {"title": "Original Goal", "type": "string"},
"thought": {
# "title": "Thought about how the next step brings us closer to achieving the original goal",
"type": "string"
},
"next_step": {
"title": "Next Step: either a result or one or more tool calls to achieve the original goal",
"oneOf": [
{
"title": "Tool Calls",
"properties": {
# "type": {
# "const": "tool_calls"
# },
"tool_calls": {
"type": "array",
"items": tool_call_schema
}
},
"required": ["tool_calls"]
},
{
"title": "Result (achieving original goal)",
"properties": {
"result": response_schema,
},
"required": ["result"]
},
]
},
},
"required": ["original_goal", "thought", "next_step"]
}
class BespokeToolsChatHandler(ChatHandler):
def __init__(self, args: ChatHandlerArgs):
super().__init__(args)
# args.response_schema = args.response_schema or {}
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
response_schema = args.response_schema or {"type": "string"}
converter.visit(
_make_bespoke_schema(
response_schema,
{
"oneOf": [
{
"type": "object",
"properties": {
"name": {"const": tool.function.name},
"arguments": tool.function.parameters,
},
"required": ["name", "arguments"]
}
for tool in self.args.tools
]
}
),
'',
)
self.grammar = converter.format_grammar()
self.output_format_prompt = Message(
role="system",
content='\n'.join([
'You are a function calling AI model.',
'Here are the tools available:',
_tools_schema_signatures(self.args.tools, indent=2),
_please_respond_with_schema(
_make_bespoke_schema(
response_schema,
{
"properties": {
"name": {
"title": "Name of the tool to call",
"type": "string"
},
"arguments": {
"title": "Arguments to pass to the tool",
"type": "object"
}
},
"required": ["name", "arguments"]
}
)
),
])
)
@typechecked
def parse(self, s: str) -> Optional[Message]:
s = self.args.chat_template.strip_suffix(s)
try:
data = json.loads(s)
except json.JSONDecodeError:
raise ValueError(f'Failed to parse data as JSON: {s}')
next_step = data['next_step']
if 'result' in next_step:
return Message(role="assistant", content=json.dumps(next_step['result']))
elif 'tool_calls' in next_step:
return Message(
role="assistant",
content=data["thought"],
tool_calls=[
ToolCall(id=gen_callid(), function=FunctionCall(**tc))
for tc in next_step['tool_calls']
]
)
else:
raise ValueError(f'Unexpected data: {data}')
_SHORT_TEMPLATE='\n'.join([
'Here are the tools available:',
'<tools>',
'{tools}',
'</tools>',
])
_LONG_TEMPLATE='\n'.join([
# '''You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.''',
'You may call one or more functions to assist with the user query. Don\'t make assumptions about what values to plug into functions. Here are the available tools:',
'<tools>',
'{tools}',
'</tools>',
'',
# 'Use the following json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}',
# '',
# 'For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:',
'To call each function, give its name and arguments within <tool_call></tool_call> XML tags as follows:',
'<tool_call>',
'{"name": <function-name>, "arguments": <args-dict>}',
'</tool_call>',
# 'This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.''',
])
def get_chat_handler(args: ChatHandlerArgs, allow_parallel_calls=False) -> ChatHandler:
if not args.tools:
return NoToolsChatHandler(args)
elif args.chat_template.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
return FunctionaryToolsChatHandler(args)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_SHORT:
return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_LONG:
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL:
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, escapes_underscores=True, allow_parallel_calls=allow_parallel_calls)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE:
return BespokeToolsChatHandler(args)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
return Hermes2ProToolsChatHandler(args)
else:
raise ValueError(f"Unsupported tool call style: {chat_format.tool_style}")
raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}")
_ts_converter = SchemaToTypeScriptConverter()
def _please_respond_with_schema(schema: dict) -> str:
# sig = json.dumps(schema, indent=2)
sig = _ts_converter.visit(schema)
return f'Please respond in JSON format with the following schema: {sig}'
def _tools_typescript_signatures(tools: list[Tool]) -> str:
ts_converter = SchemaToTypeScriptConverter()
return 'namespace functions {' + '\n'.join(
'// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
'type ' + tool.function.name + ' = (_: ' + ts_converter.visit(tool.function.parameters) + ") => any;\n"
'type ' + tool.function.name + ' = (_: ' + _ts_converter.visit(tool.function.parameters) + ") => any;\n"
for tool in tools
) + '} // namespace functions'
@ -185,247 +587,9 @@ def _tools_schema_signatures(tools: list[Tool], indent=None) -> str:
for tool in tools
)
@typechecked
def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
return style in (
ToolsPromptStyle.TOOLS_SHORT,
ToolsPromptStyle.TOOLS_LONG,
ToolsPromptStyle.TOOLS_HERMES_2_PRO,
)
_tool_call_re = re.compile(
'<tool_call>(.*?)</tool_call>', re.DOTALL)
_recipient_content_re = re.compile(r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)
def gen_callid():
return f'call_{random.randint(0, 1000000)}'
@typechecked
def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[list[Message]]]]:
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
response_rule = converter.visit(response_schema, "response") if response_schema else None
delimiter = '<%$[SAMPLE]$%>'
user_msg = Message(role="user", content="Hey")
empty_prompt = chat_format.render([user_msg], add_generation_prompt=True).strip()
planted_prompt = chat_format.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
[prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
allow_parallel_calls = False
def strip_suffix(s: str) -> str:
if s.endswith(suffix):
return s[:-len(suffix)]
else:
sys.stderr.write(f"Expected suffix ({suffix}) not found: {s}\n")
return s
if tools:
if _outputs_tool_call_tags(chat_format.tool_style):
escapes_underscores = chat_format.tool_style != ToolsPromptStyle.TOOLS_HERMES_2_PRO
tool_rules = [
converter.visit(
dict(
type="object",
properties=dict(
name=dict(type="string", pattern='^' + tool.function.name.replace('_', f'\\?_') + '$') if escapes_underscores \
else dict(const=tool.function.name),
arguments=tool.function.parameters,
),
required=['name', 'arguments']
),
f'{tool.function.name}-tool-call'
)
for tool in tools
]
def format_literal(s: str) -> str:
if escapes_underscores:
return ' "\\\\"? "_" '.join((converter._format_literal(part) for part in s.split('_')))
else:
return converter._format_literal(s)
tool_call_rule = converter._add_rule(
'tool_call',
format_literal("<tool_call>") + " space (" +
' | '.join(tool_rules) +
") space " + format_literal("</tool_call>"))# + ' space')
# Ideally we'd want a negative lookahead of /<tool\\?_call>/, but it's just too hard to express in GBNF for now.
# So we just over-constrain the content rule to not contain literals dangerously getting close to <tool_call>
content_rule = converter._add_rule('content', '[^<] | "<" [^t<] | "<t" [^o<]')
# content_rule = converter._add_rule('content', converter.not_literal('<tool_call>'))
converter._add_rule(
'root',
# tool_call_rule)
f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if allow_parallel_calls \
else f'{content_rule}* {tool_call_rule}?')
# # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
# # OR a tool-call message respecting the schema of any of the tools
# converter._add_rule(
# "root",
# converter._format_literal(prefix) + " (" +
# (response_rule or converter.not_literal("<tool_call>")) + " | " +
# converter._format_literal("<tool_call>") + " (" +
# ' | '.join(tool_rules) +
# ") " + converter._format_literal("</tool_call>") +
# ")") # + converter._format_literal(suffix))
@typechecked
def parse(s: str) -> Optional[Message]:
s = strip_suffix(s)
if r'<tool\_call>' in s:
# Some weird escaping of underscores is happening w/ Mixtral 8x7B Instruct
s = s.replace(r'\_', '_')
parts = _tool_call_re.split(s)
if len(parts) == 1:
return Message(role="assistant", content=s)
else:
content = []
tool_calls = []
for i, part in enumerate(parts):
if i % 2 == 0:
content.append(part)
else:
try:
fc = json.loads(part)
except json.JSONDecodeError:
raise ValueError(f'Failed to parse tool call as JSON: {part}\nFull string: {s}')
tool_calls.append(
ToolCall(
id=gen_callid(),
function=FunctionCall(**fc)))
content = '\n'.join(content).strip()
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
# if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
# if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
# tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
# return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
# return None
# else:
# return Message(role="assistant", content=s)
return (converter.format_grammar(), parse)
elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
# Only allowing a single tool call at a time for now.
# Note that if there were more, they'd be separated by a '<|from|>assistant' literal
tool_rules = [
converter._add_rule(
tool.function.name + '-call',
converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
converter._format_literal('\n'))
# converter.visit(
# dict(
# type="object",
# properties=dict(
# name=dict(const=tool.function.name),
# arguments=tool.function.parameters,
# ),
# required=['name', 'arguments']
# ),
# f'{tool.function.name}-tool-call'
# )
for i, tool in enumerate(tools)
]
not_from_rule = converter._add_rule('not_from', converter.not_literal("<|from|>"))
content_without_start_rule = converter._add_rule(
'content_without_start',
converter._format_literal("all\n<|content|>") + ' ' + not_from_rule + '*')
start_rule = converter._add_rule('start', converter._format_literal('<|from|>assistant\n<|recipient|>'))
content_rule = converter._add_rule('content', start_rule + ' ' + content_without_start_rule)
tool_call_without_start_rule = converter._add_rule(
'tool_call_without_start',
' | '.join(tool_rules))
# + ' ' +
# converter.not_literal("all", dotall=False) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + not_from_rule + '*')
tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}')
# converter._add_rule('root', f'({content_without_start_rule} ({content_rule})* ({tool_call_rule}+ {content_rule}*)? | {tool_call_without_start_rule} (* {tool_call_rule}{content_rule}*')
converter._add_rule(
'root',
f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if allow_parallel_calls \
else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
# converter._add_rule(
# "root",
# converter._format_literal(prefix) + " (" +
# (response_rule or converter.not_literal("<|recipient|>")) + " | " +
# (' | '.join(
# converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
# converter.visit(tool.function.parameters, tool.function.name + '-args')
# for tool in tools
# )) +
# ") " +
# ")") # + converter._format_literal(suffix))
@typechecked
def parse(s: str) -> Optional[Message]:
s = strip_suffix(s)
parts = _recipient_content_re.split(s)
if len(parts) == 1:
return Message(role="assistant", content=s)
else:
text_content = []
tool_calls: list[ToolCall] = []
for i in range((len(parts) - 1) // 3):
assert parts[i * 3].strip() == '', f'Unexpected content before tool call: {parts[i * 3]}'
recipient = parts[i * 3 + 1].strip()
content = parts[i * 3 + 2]
if recipient == 'all':
text_content.append(content)
else:
try:
arguments = json.loads(content)
except json.JSONDecodeError:
raise ValueError(f'Failed to parse tool call content as JSON: {content}')
tool_calls.append(
ToolCall(
id=gen_callid(),
function=FunctionCall(name=recipient, arguments=arguments)))
assert parts[-1].strip() in ('', '<|stop|>'), f'Unexpected content after tool calls: {parts[-1]}\nFull string: {s}'
content = '\n'.join(text_content).strip()
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
return (converter.format_grammar(), parse)
else:
raise ValueError(f"Unsupported tool call style: {chat_format.tool_style}")
elif response_schema:
converter._add_rule("root", response_rule + ' ' + converter._format_literal(suffix))
@typechecked
def parse(s: str) -> Optional[Message]:
s = strip_suffix(s)
return Message(role="assistant", content=s)
return (converter.format_grammar(), parse)
else:
converter._add_rule("root", converter._format_literal(prefix) + ' ' + converter._format_literal(suffix))
@typechecked
def parse(s: str) -> Optional[Message]:
s = strip_suffix(s)
return Message(role="assistant", content=s)
return (None, parse)

View file

@ -5,12 +5,14 @@ import json, sys, subprocess, atexit
from pathlib import Path
import time
from pydantic import TypeAdapter
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
from examples.openai.gguf_kvs import GGUFKeyValues, Keys
from examples.openai.api import ChatCompletionResponse, Choice, Message, ChatCompletionRequest, Usage
from examples.openai.prompting import ChatFormat, make_grammar, make_tools_prompt
from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, get_chat_handler, ChatHandler
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
@ -38,8 +40,8 @@ def main(
metadata = GGUFKeyValues(model)
context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
chat_format = ChatFormat.from_gguf(metadata)
# print(chat_format)
chat_template = ChatTemplate.from_gguf(metadata)
# print(chat_template)
if not cpp_server_endpoint:
sys.stderr.write(f"# Starting C++ server with model {model} on {cpp_server_host}:{cpp_server_port}\n")
@ -69,18 +71,17 @@ def main(
else:
response_schema = None
messages = chat_request.messages
if chat_request.tools:
messages = chat_format.add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools))
(grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
# TODO: Test whether the template supports formatting tool_calls
prompt = chat_format.render(messages, add_generation_prompt=True)
chat_handler = get_chat_handler(ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools))
messages = chat_request.messages
if chat_handler.output_format_prompt:
messages = chat_template.add_system_prompt(messages, chat_handler.output_format_prompt)
prompt = chat_template.render(messages, add_generation_prompt=True)
sys.stderr.write(f'\n# MESSAGES:\n\n{TypeAdapter(list[Message]).dump_json(messages)}\n\n')
sys.stderr.write(f'\n# PROMPT:\n\n{prompt}\n\n')
sys.stderr.write(f'\n# GRAMMAR:\n\n{grammar}\n\n')
sys.stderr.write(f'\n# GRAMMAR:\n\n{chat_handler.grammar}\n\n')
data = LlamaCppServerCompletionRequest(
**{
@ -94,9 +95,10 @@ def main(
)
},
prompt=prompt,
grammar=grammar,
grammar=chat_handler.grammar,
).model_dump()
sys.stderr.write(json.dumps(data, indent=2) + "\n")
# sys.stderr.write(json.dumps(data, indent=2) + "\n")
async with httpx.AsyncClient() as client:
response = await client.post(
f"{cpp_server_endpoint}/completions",
@ -116,7 +118,7 @@ def main(
return JSONResponse(result)
# print(json.dumps(result.get('content'), indent=2))
message = parser(result["content"])
message = chat_handler.parse(result["content"])
assert message is not None, f"Failed to parse response:\n{response.text}\n\n"
prompt_tokens=result['timings']['prompt_n']