diff --git a/examples/openai/api.py b/examples/openai/api.py
index dd8da09a2..7c4a446b8 100644
--- a/examples/openai/api.py
+++ b/examples/openai/api.py
@@ -18,7 +18,7 @@ class Message(BaseModel):
class ToolFunction(BaseModel):
name: str
description: str
- parameters: Any
+ parameters: dict[str, Any]
class Tool(BaseModel):
type: str
diff --git a/examples/openai/prompting.py b/examples/openai/prompting.py
index e26ca9229..60dab69ba 100644
--- a/examples/openai/prompting.py
+++ b/examples/openai/prompting.py
@@ -1,11 +1,14 @@
+from abc import ABC, abstractmethod
from enum import Enum
+from functools import wraps
import jinja2
import json
from pathlib import Path
import random
import re
import sys
-from typing import Optional, Tuple, Callable
+from typing import Any, Dict, Literal, Optional, Tuple, Callable, Union
+from pydantic import BaseModel
from typeguard import typechecked
from examples.json_schema_to_grammar import SchemaConverter
@@ -18,22 +21,52 @@ def raise_exception(msg: str):
raise Exception(msg)
@typechecked
-class ChatFormat:
- def __init__(self, template: str, eos_token: str, bos_token: str):
- env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
- self.template = env.from_string(template)
- self.eos_token = eos_token
- self.bos_token = bos_token
+class ChatTemplate(BaseModel):
+ template: str
- self.strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template
+ @property
+ def tool_style(self) -> 'ToolsPromptStyle':
+ return self._tool_style
+
+ def __init__(self, template: str, eos_token: str, bos_token: str):
+ super().__init__(template=template)
+ env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
+ self._template = env.from_string(template)
+ self._eos_token = eos_token
+ self._bos_token = bos_token
+
+ self._strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template
if "<|recipient|>' + tool_call['function']['name']" in template:
- self.tool_style = ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2
+ self._tool_style = ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2
else:
- self.tool_style = ToolsPromptStyle.TOOLS_LONG
+ self._tool_style = ToolsPromptStyle.TOOLS_BESPOKE
+ # self._tool_style = ToolsPromptStyle.TOOLS_LONG
+
+ # TODO: Test whether the template supports formatting tool_calls
+
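+ # Probe the template to recover the assistant-turn prefix and suffix:
+ # render once with only a user message (generation prompt added), then
+ # once more with a planted assistant message holding a unique delimiter.
+ # The text surrounding the delimiter beyond the first render is the pair,
+ # e.g. a ChatML-style template would yield prefix '<|im_start|>assistant\n'
+ # and suffix '<|im_end|>' (the exact tokens depend on the template).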
+ delimiter = '<%$[SAMPLE]$%>'
+ user_msg = Message(role="user", content="Hey")
+ empty_prompt = self.render([user_msg], add_generation_prompt=True).strip()
+ planted_prompt = self.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
+ assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
+ [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
+
+ sys.stderr.write(f"\n# prefix={prefix}\n# suffix={suffix}\n\n")
+
+ self._prefix = prefix
+ self._suffix = suffix
+
+ def strip_suffix(self, s: str) -> str:
+ if s.endswith(self._suffix):
+ return s[:-len(self._suffix)]
+ else:
+ sys.stderr.write(f"Expected suffix ({self._suffix}) not found: {s}\n")
+ return s
def __str__(self):
- return f"ChatFormat(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})"
+ return f"ChatTemplate(template={self.template}, eos_token={self._eos_token}, bos_token={self._bos_token})"
def add_system_prompt(self, messages: list[Message], system_prompt: Message) -> list[Message]:
assert system_prompt.role == "system"
@@ -48,13 +81,13 @@ class ChatFormat:
@staticmethod
def from_gguf(metadata: GGUFKeyValues):
tokens = metadata[Keys.Tokenizer.LIST]
- return ChatFormat(
+ return ChatTemplate(
template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
bos_token = tokens[metadata[Keys.Tokenizer.BOS_ID]],
eos_token = tokens[metadata[Keys.Tokenizer.EOS_ID]])
def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
- if self.strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
+ if self._strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
new_messages=[]
i = 0
n = len(messages)
@@ -80,10 +113,10 @@ class ChatFormat:
messages = new_messages
# print(f'messages={messages}')
- result = self.template.render(
+ result = self._template.render(
messages=messages,
- eos_token=self.eos_token,
- bos_token='' if omit_bos else self.bos_token,
+ eos_token=self._eos_token,
+ bos_token='' if omit_bos else self._bos_token,
raise_exception=raise_exception,
add_generation_prompt=add_generation_prompt,
)
@@ -95,67 +128,176 @@ class ChatFormat:
# each model may need specific prompting (and/or constrained output,
# especially for models not fine-tuned for tool usage / function calling).
class ToolsPromptStyle(Enum):
- # Short prompt w/ schemas
+ # Short prompt w/ schemas, <tool_call>...</tool_call> output
TOOLS_SHORT = 1
- # Longer prompt w/ schemas
+ # Longer prompt w/ schemas, <tool_call>...</tool_call> output
TOOLS_LONG = 2
+ # Bespoke constrained output format that favours thought and reasoning
+ # while allowing unambiguous parsing of parallel tool calling.
+ TOOLS_BESPOKE = 3
+
# Large prompt for https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B
+ # <tool_call>...</tool_call> output
# Requires:
# - git clone https://github.com/NousResearch/Hermes-Function-Calling examples/openai/hermes_function_calling
# - Set large context length as their prompts are super long
- TOOLS_HERMES_2_PRO = 3
+ TOOLS_HERMES_2_PRO = 4
+
+ # Seems to want to escape underscores in tool names and in the <tool\_call> tags
+ TOOLS_MISTRAL = 5
# Short prompt w/ TypeScript definitions for https://github.com/MeetKai/functionary
# https://github.com/MeetKai/functionary/blob/main/functionary/prompt_template/prompt_template_v2.py
# Note: see this prior attempt to support Functionary: https://github.com/ggerganov/llama.cpp/pull/5695
- TYPESCRIPT_FUNCTIONARY_V2 = 4
+ TYPESCRIPT_FUNCTIONARY_V2 = 6
-@typechecked
-def make_tools_prompt(chat_format: ChatFormat, tools: list[Tool], indent=2) -> Message:
+class ChatHandlerArgs(BaseModel):
+ chat_template: ChatTemplate
+ response_schema: Optional[dict] = None
+ tools: Optional[list[Tool]] = None
- if chat_format.tool_style == ToolsPromptStyle.TOOLS_SHORT:
- return Message(
+class ChatHandler(ABC):
+ def __init__(self, args: ChatHandlerArgs):
+ self.args = args
+ self.output_format_prompt: Optional[Message] = None
+ self.grammar: Optional[str] = None
+
+ @abstractmethod
+ def parse(self, s: str) -> Optional[Message]:
+ raise NotImplementedError()
+
+class NoToolsChatHandler(ChatHandler):
+ def __init__(self, args: ChatHandlerArgs):
+ super().__init__(args)
+ assert not args.tools
+
+ if args.response_schema:
+ self.output_format_prompt = Message(
+ role="system",
+ content=_please_respond_with_schema(args.response_schema)
+ )
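+ # Compile response_schema into a GBNF grammar (visit('') registers it
+ # under the grammar's root rule) so the C++ server constrains decoding
+ # to JSON matching the schema.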
+ converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
+ converter.visit(args.response_schema, '')
+ self.grammar = converter.format_grammar()
+ else:
+ self.output_format_prompt = None
+ self.grammar = None
+
+ @typechecked
+ def parse(self, s: str) -> Optional[Message]:
+ return Message(role="assistant", content=s)
+
+class ToolCallTagsChatHandler(ChatHandler):
+ def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, allow_parallel_calls: bool):
+ super().__init__(args)
+
+ converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
+ tool_rules = [
+ converter.visit(
+ dict(
+ type="object",
+ properties=dict(
+ name=dict(type="string", pattern='^' + tool.function.name.replace('_', f'\\?_') + '$') if escapes_underscores \
+ else dict(const=tool.function.name),
+ arguments=tool.function.parameters,
+ ),
+ required=['name', 'arguments']
+ ),
+ f'{tool.function.name}-tool-call'
+ )
+ for tool in self.args.tools
+ ]
+
+ def format_literal(s: str) -> str:
+ if escapes_underscores:
+ return ' "\\\\"? "_" '.join((converter._format_literal(part) for part in s.split('_')))
+ else:
+ return converter._format_literal(s)
+
+ tool_call_rule = converter._add_rule(
+ 'tool_call',
+ format_literal("") + " space (" +
+ ' | '.join(tool_rules) +
+ ") space " + format_literal(""))# + ' space')
+
+ # Ideally we'd want a negative lookahead of /<tool_call>/, but it's just too hard to express in GBNF for now.
+ # So we just over-constrain the content rule to not contain literals dangerously getting close to <tool_call>
+ content_rule = converter._add_rule('content', '[^<] | "<" [^t<] | "<t" [^o<]')
+ converter._add_rule(
+ 'root',
+ # tool_call_rule)
+ f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if allow_parallel_calls \
+ else f'{content_rule}* {tool_call_rule}?')
+ self.grammar = converter.format_grammar()
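+ # Example of a completion this grammar admits (sketch w/ a hypothetical
+ # `get_weather` tool; content may surround one tool call, or several when
+ # parallel calls are allowed):
+ #   Sure, let me check. <tool_call> {"name": "get_weather", "arguments": {"location": "Paris"}} </tool_call>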
+
+ # # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
+ # # OR a tool-call message respecting the schema of any of the tools
+ # converter._add_rule(
+ # "root",
+ # converter._format_literal(prefix) + " (" +
+ # (response_rule or converter.not_literal("<tool_call>")) + " | " +
+ # converter._format_literal("<tool_call>") + " (" +
+ # ' | '.join(tool_rules) +
+ # ") " + converter._format_literal("</tool_call>") +
+ # ")") # + converter._format_literal(suffix))
+
+ @typechecked
+ def parse(self, s: str) -> Optional[Message]:
+ s = self.args.chat_template.strip_suffix(s)
+
+ if r'<tool\_call>' in s:
+ # Some weird escaping of underscores is happening w/ Mixtral 8x7B Instruct
+ s = s.replace(r'\_', '_')
+
+ parts = _tool_call_re.split(s)
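+ # _tool_call_re has a single capturing group, so re.split() alternates
+ # plain content (even indices) with captured tool-call JSON (odd indices).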
+ if len(parts) == 1:
+ return Message(role="assistant", content=s)
+ else:
+ content = []
+ tool_calls = []
+ for i, part in enumerate(parts):
+ if i % 2 == 0:
+ content.append(part)
+ else:
+ try:
+ fc = json.loads(part)
+ except json.JSONDecodeError:
+ raise ValueError(f'Failed to parse tool call as JSON: {part}\nFull string: {s}')
+ tool_calls.append(
+ ToolCall(
+ id=gen_callid(),
+ function=FunctionCall(**fc)))
+
+ content = '\n'.join(content).strip()
+ return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
+
+ # if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
+ # if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
+ # tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
+ # return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
+ # return None
+ # else:
+ # return Message(role="assistant", content=s)
+
+class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
+ def __init__(self, args: ChatHandlerArgs, template: str, escapes_underscores=False, allow_parallel_calls=True):
+ super().__init__(args, escapes_underscores=escapes_underscores, allow_parallel_calls=allow_parallel_calls)
+ assert '{tools}' in template, 'Template must contain "{tools}"'
+
+ self.output_format_prompt = Message(
role="system",
- content='\n'.join([
- 'Here are the tools available:',
- '<tools>',
- *(json.dumps(tool.model_dump(), indent=indent) for tool in tools),
- '</tools>',
- ])
+ content=template.replace(
+ '{tools}',
+ '\n'.join(json.dumps(tool.model_dump(), indent=2) for tool in self.args.tools),
+ )
)
-
- elif chat_format.tool_style == ToolsPromptStyle.TOOLS_LONG:
- return Message(
- role="system",
- content='\n'.join([
- # '''You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.''',
- '''You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:''',
- '''<tools>''',
- _tools_typescript_signatures(tools),
- # _tools_schema_signatures(tools, indent=indent),
- '''</tools>''',
- '',
- # '''Use the following json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}''',
- # '',
- # '''For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:''',
- '''To call each function, give its name and arguments within <tool_call></tool_call> XML tags as follows:''',
- '''<tool_call>''',
- '''{"name": <function-name>, "arguments": <args-dict>}''',
- '''</tool_call>''',
- # '''This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.''',
- ])
- )
-
- elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
- return Message(
- role="system",
- content= '// Supported function definitions that should be called when necessary.\n' +
- _tools_typescript_signatures(tools)
- )
-
- elif chat_format.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
+
+class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
+ def __init__(self, args: ChatHandlerArgs):
+ super().__init__(args, escapes_underscores=False, allow_parallel_calls=False)
+
# Hackily import https://github.com/NousResearch/Hermes-Function-Calling
path = str(Path(__file__).parent / "hermes_function_calling")
if path not in sys.path: sys.path.insert(0, path)
@@ -166,16 +308,276 @@ def make_tools_prompt(chat_format: ChatFormat, tools: list[Tool], indent=2) -> M
- prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in tools])
+ prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool.model_dump()) for tool in self.args.tools])
assert len(prompt) == 1 and prompt[0]["role"] == "system"
- return Message(**prompt[0])
+ self.output_format_prompt = Message(**prompt[0])
+
+class FunctionaryToolsChatHandler(ChatHandler):
+ def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
+ super().__init__(args)
+
+ # Only allowing a single tool call at a time for now.
+ # Note that if there were more, they'd be separated by a '<|from|>assistant' literal
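+ #
+ # Sketch of the wire format the grammar below constrains (hypothetical
+ # `get_weather` tool; the literals match the rules added below):
+ #   <|from|>assistant
+ #   <|recipient|>all
+ #   <|content|>Let me check that.
+ #   <|from|>assistant
+ #   <|recipient|>get_weather
+ #   <|content|>
+ #   {"location": "Paris"}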
+
+ self.output_format_prompt = Message(
+ role="system",
+ content= '// Supported function definitions that should be called when necessary.\n' +
+ _tools_typescript_signatures(args.tools)
+ )
+ converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
+ tool_rules = [
+ converter._add_rule(
+ tool.function.name + '-call',
+ converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
+ converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
+ converter._format_literal('\n'))
+ # converter.visit(
+ # dict(
+ # type="object",
+ # properties=dict(
+ # name=dict(const=tool.function.name),
+ # arguments=tool.function.parameters,
+ # ),
+ # required=['name', 'arguments']
+ # ),
+ # f'{tool.function.name}-tool-call'
+ # )
+ for i, tool in enumerate(self.args.tools)
+ ]
+
+ not_from_rule = converter._add_rule('not_from', converter.not_literal("<|from|>"))
+ content_without_start_rule = converter._add_rule(
+ 'content_without_start',
+ converter._format_literal("all\n<|content|>") + ' ' + not_from_rule + '*')
+ start_rule = converter._add_rule('start', converter._format_literal('<|from|>assistant\n<|recipient|>'))
+ content_rule = converter._add_rule('content', start_rule + ' ' + content_without_start_rule)
+ tool_call_without_start_rule = converter._add_rule(
+ 'tool_call_without_start',
+ ' | '.join(tool_rules))
+ # + ' ' +
+ # converter.not_literal("all", dotall=False) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + not_from_rule + '*')
+ tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}')
+ # converter._add_rule('root', f'({content_without_start_rule} ({content_rule})* ({tool_call_rule}+ {content_rule}*)? | {tool_call_without_start_rule} (* {tool_call_rule}{content_rule}*')
+ converter._add_rule(
+ 'root',
+ f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
+ f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if allow_parallel_calls \
+ else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
+
+ self.grammar = converter.format_grammar()
+ # converter._add_rule(
+ # "root",
+ # converter._format_literal(prefix) + " (" +
+ # (response_rule or converter.not_literal("<|recipient|>")) + " | " +
+ # (' | '.join(
+ # converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
+ # converter.visit(tool.function.parameters, tool.function.name + '-args')
+ # for tool in tools
+ # )) +
+ # ") " +
+ # ")") # + converter._format_literal(suffix))
+
+ @typechecked
+ def parse(self, s: str) -> Optional[Message]:
+ s = self.args.chat_template.strip_suffix(s)
+
+ parts = _recipient_content_re.split(s)
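+ # _recipient_content_re has two capturing groups (recipient, content), so
+ # re.split() yields [pre-text, recipient, content, ...] in strides of 3,
+ # plus a trailing part that is checked after the loop.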
+ if len(parts) == 1:
+ return Message(role="assistant", content=s)
+ else:
+ text_content = []
+ tool_calls: list[ToolCall] = []
+ for i in range((len(parts) - 1) // 3):
+ assert parts[i * 3].strip() == '', f'Unexpected content before tool call: {parts[i * 3]}'
+ recipient = parts[i * 3 + 1].strip()
+ content = parts[i * 3 + 2]
+ if recipient == 'all':
+ text_content.append(content)
+ else:
+ try:
+ arguments = json.loads(content)
+ except json.JSONDecodeError:
+ raise ValueError(f'Failed to parse tool call content as JSON: {content}')
+ tool_calls.append(
+ ToolCall(
+ id=gen_callid(),
+ function=FunctionCall(name=recipient, arguments=arguments)))
+
+
+ assert parts[-1].strip() in ('', '<|stop|>'), f'Unexpected content after tool calls: {parts[-1]}\nFull string: {s}'
+
+ content = '\n'.join(text_content).strip()
+ return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
+
+def _make_bespoke_schema(response_schema, tool_call_schema):
+ return {
+ "type": "object",
+ "properties": {
+ # "original_goal": {"title": "Original Goal", "type": "string"},
+ "thought": {
+ # "title": "Thought about how the next step brings us closer to achieving the original goal",
+ "type": "string"
+ },
+ "next_step": {
+ "title": "Next Step: either a result or one or more tool calls to achieve the original goal",
+ "oneOf": [
+ {
+ "title": "Tool Calls",
+ "properties": {
+ # "type": {
+ # "const": "tool_calls"
+ # },
+ "tool_calls": {
+ "type": "array",
+ "items": tool_call_schema
+ }
+ },
+ "required": ["tool_calls"]
+ },
+ {
+ "title": "Result (achieving original goal)",
+ "properties": {
+ "result": response_schema,
+ },
+ "required": ["result"]
+ },
+ ]
+ },
+ },
+ "required": ["original_goal", "thought", "next_step"]
+ }
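+
+# A completion conforming to the bespoke schema looks like this (sketch,
+# with a hypothetical `get_weather` tool):
+#   {"thought": "Need the weather first.",
+#    "next_step": {"tool_calls": [{"name": "get_weather", "arguments": {"location": "Paris"}}]}}
+# or, once the original goal is achieved:
+#   {"thought": "All done.", "next_step": {"result": ...}}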
+
+class BespokeToolsChatHandler(ChatHandler):
+ def __init__(self, args: ChatHandlerArgs):
+ super().__init__(args)
+
+ # args.response_schema = args.response_schema or {}
+ converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
+
+ response_schema = args.response_schema or {"type": "string"}
+ converter.visit(
+ _make_bespoke_schema(
+ response_schema,
+ {
+ "oneOf": [
+ {
+ "type": "object",
+ "properties": {
+ "name": {"const": tool.function.name},
+ "arguments": tool.function.parameters,
+ },
+ "required": ["name", "arguments"]
+ }
+ for tool in self.args.tools
+ ]
+ }
+ ),
+ '',
+ )
+ self.grammar = converter.format_grammar()
+
+ self.output_format_prompt = Message(
+ role="system",
+ content='\n'.join([
+ 'You are a function calling AI model.',
+ 'Here are the tools available:',
+ _tools_schema_signatures(self.args.tools, indent=2),
+ _please_respond_with_schema(
+ _make_bespoke_schema(
+ response_schema,
+ {
+ "properties": {
+ "name": {
+ "title": "Name of the tool to call",
+ "type": "string"
+ },
+ "arguments": {
+ "title": "Arguments to pass to the tool",
+ "type": "object"
+ }
+ },
+ "required": ["name", "arguments"]
+ }
+ )
+ ),
+ ])
+ )
+
+ @typechecked
+ def parse(self, s: str) -> Optional[Message]:
+ s = self.args.chat_template.strip_suffix(s)
+ try:
+ data = json.loads(s)
+ except json.JSONDecodeError:
+ raise ValueError(f'Failed to parse data as JSON: {s}')
+
+ next_step = data['next_step']
+ if 'result' in next_step:
+ return Message(role="assistant", content=json.dumps(next_step['result']))
+ elif 'tool_calls' in next_step:
+ return Message(
+ role="assistant",
+ content=data["thought"],
+ tool_calls=[
+ ToolCall(id=gen_callid(), function=FunctionCall(**tc))
+ for tc in next_step['tool_calls']
+ ]
+ )
+ else:
+ raise ValueError(f'Unexpected data: {data}')
+
+_SHORT_TEMPLATE='\n'.join([
+ 'Here are the tools available:',
+ '<tools>',
+ '{tools}',
+ '</tools>',
+])
+
+_LONG_TEMPLATE='\n'.join([
+ # '''You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.''',
+ 'You may call one or more functions to assist with the user query. Don\'t make assumptions about what values to plug into functions. Here are the available tools:',
+ '<tools>',
+ '{tools}',
+ '</tools>',
+ '',
+ # 'Use the following json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}',
+ # '',
+ # 'For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:',
+ 'To call each function, give its name and arguments within <tool_call></tool_call> XML tags as follows:',
+ '<tool_call>',
+ '{"name": <function-name>, "arguments": <args-dict>}',
+ '</tool_call>',
+ # 'This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.',
+])
+
+def get_chat_handler(args: ChatHandlerArgs, allow_parallel_calls=False) -> ChatHandler:
+ if not args.tools:
+ return NoToolsChatHandler(args)
+ elif args.chat_template.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
+ return FunctionaryToolsChatHandler(args, allow_parallel_calls=allow_parallel_calls)
+ elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_SHORT:
+ return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
+ elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_LONG:
+ return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
+ elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL:
+ return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, escapes_underscores=True, allow_parallel_calls=allow_parallel_calls)
+ elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE:
+ return BespokeToolsChatHandler(args)
+ elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
+ return Hermes2ProToolsChatHandler(args)
else:
- raise ValueError(f"Unsupported tool call style: {chat_format.tool_style}")
-
+ raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}")
+
+_ts_converter = SchemaToTypeScriptConverter()
+
+def _please_respond_with_schema(schema: dict) -> str:
+ # sig = json.dumps(schema, indent=2)
+ sig = _ts_converter.visit(schema)
+ return f'Please respond in JSON format with the following schema: {sig}'
+
def _tools_typescript_signatures(tools: list[Tool]) -> str:
- ts_converter = SchemaToTypeScriptConverter()
return 'namespace functions {' + '\n'.join(
'// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
- 'type ' + tool.function.name + ' = (_: ' + ts_converter.visit(tool.function.parameters) + ") => any;\n"
+ 'type ' + tool.function.name + ' = (_: ' + _ts_converter.visit(tool.function.parameters) + ") => any;\n"
for tool in tools
) + '} // namespace functions'
@@ -185,247 +587,9 @@ def _tools_schema_signatures(tools: list[Tool], indent=None) -> str:
for tool in tools
)
-@typechecked
-def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
- return style in (
- ToolsPromptStyle.TOOLS_SHORT,
- ToolsPromptStyle.TOOLS_LONG,
- ToolsPromptStyle.TOOLS_HERMES_2_PRO,
- )
-
_tool_call_re = re.compile(
'<tool_call>(.*?)</tool_call>', re.DOTALL)
_recipient_content_re = re.compile(r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)
def gen_callid():
return f'call_{random.randint(0, 1000000)}'
-
-@typechecked
-def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[list[Message]]]]:
-
- converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
-
- response_rule = converter.visit(response_schema, "response") if response_schema else None
-
- delimiter = '<%$[SAMPLE]$%>'
- user_msg = Message(role="user", content="Hey")
- empty_prompt = chat_format.render([user_msg], add_generation_prompt=True).strip()
- planted_prompt = chat_format.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
- assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
- [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
-
- allow_parallel_calls = False
-
- def strip_suffix(s: str) -> str:
- if s.endswith(suffix):
- return s[:-len(suffix)]
- else:
- sys.stderr.write(f"Expected suffix ({suffix}) not found: {s}\n")
- return s
-
- if tools:
- if _outputs_tool_call_tags(chat_format.tool_style):
-
- escapes_underscores = chat_format.tool_style != ToolsPromptStyle.TOOLS_HERMES_2_PRO
-
- tool_rules = [
- converter.visit(
- dict(
- type="object",
- properties=dict(
- name=dict(type="string", pattern='^' + tool.function.name.replace('_', f'\\?_') + '$') if escapes_underscores \
- else dict(const=tool.function.name),
- arguments=tool.function.parameters,
- ),
- required=['name', 'arguments']
- ),
- f'{tool.function.name}-tool-call'
- )
- for tool in tools
- ]
-
- def format_literal(s: str) -> str:
- if escapes_underscores:
- return ' "\\\\"? "_" '.join((converter._format_literal(part) for part in s.split('_')))
- else:
- return converter._format_literal(s)
-
- tool_call_rule = converter._add_rule(
- 'tool_call',
- format_literal("") + " space (" +
- ' | '.join(tool_rules) +
- ") space " + format_literal(""))# + ' space')
-
- # Ideally we'd want a negative lookahead of //, but it's just too hard to express in GBNF for now.
- # So we just over-constrain the content rule to not contain literals dangerously getting close to
- content_rule = converter._add_rule('content', '[^<] | "<" [^t<] | "'))
- converter._add_rule(
- 'root',
- # tool_call_rule)
- f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if allow_parallel_calls \
- else f'{content_rule}* {tool_call_rule}?')
-
- # # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
- # # OR a tool-call message respecting the schema of any of the tools
- # converter._add_rule(
- # "root",
- # converter._format_literal(prefix) + " (" +
- # (response_rule or converter.not_literal("<tool_call>")) + " | " +
- # converter._format_literal("<tool_call>") + " (" +
- # ' | '.join(tool_rules) +
- # ") " + converter._format_literal("</tool_call>") +
- # ")") # + converter._format_literal(suffix))
-
- @typechecked
- def parse(s: str) -> Optional[Message]:
- s = strip_suffix(s)
-
- if r'<tool\_call>' in s:
- # Some weird escaping of underscores is happening w/ Mixtral 8x7B Instruct
- s = s.replace(r'\_', '_')
-
- parts = _tool_call_re.split(s)
- if len(parts) == 1:
- return Message(role="assistant", content=s)
- else:
- content = []
- tool_calls = []
- for i, part in enumerate(parts):
- if i % 2 == 0:
- content.append(part)
- else:
- try:
- fc = json.loads(part)
- except json.JSONDecodeError:
- raise ValueError(f'Failed to parse tool call as JSON: {part}\nFull string: {s}')
- tool_calls.append(
- ToolCall(
- id=gen_callid(),
- function=FunctionCall(**fc)))
-
- content = '\n'.join(content).strip()
- return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
-
- # if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
- # if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
- # tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
- # return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
- # return None
- # else:
- # return Message(role="assistant", content=s)
-
- return (converter.format_grammar(), parse)
-
- elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
- # Only allowing a single tool call at a time for now.
- # Note that if there were more, they'd be separated by a '<|from|>assistant' literal
-
- tool_rules = [
- converter._add_rule(
- tool.function.name + '-call',
- converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
- converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
- converter._format_literal('\n'))
- # converter.visit(
- # dict(
- # type="object",
- # properties=dict(
- # name=dict(const=tool.function.name),
- # arguments=tool.function.parameters,
- # ),
- # required=['name', 'arguments']
- # ),
- # f'{tool.function.name}-tool-call'
- # )
- for i, tool in enumerate(tools)
- ]
-
- not_from_rule = converter._add_rule('not_from', converter.not_literal("<|from|>"))
- content_without_start_rule = converter._add_rule(
- 'content_without_start',
- converter._format_literal("all\n<|content|>") + ' ' + not_from_rule + '*')
- start_rule = converter._add_rule('start', converter._format_literal('<|from|>assistant\n<|recipient|>'))
- content_rule = converter._add_rule('content', start_rule + ' ' + content_without_start_rule)
- tool_call_without_start_rule = converter._add_rule(
- 'tool_call_without_start',
- ' | '.join(tool_rules))
- # + ' ' +
- # converter.not_literal("all", dotall=False) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + not_from_rule + '*')
- tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}')
- # converter._add_rule('root', f'({content_without_start_rule} ({content_rule})* ({tool_call_rule}+ {content_rule}*)? | {tool_call_without_start_rule} (* {tool_call_rule}{content_rule}*')
- converter._add_rule(
- 'root',
- f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
- f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if allow_parallel_calls \
- else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
-
- # converter._add_rule(
- # "root",
- # converter._format_literal(prefix) + " (" +
- # (response_rule or converter.not_literal("<|recipient|>")) + " | " +
- # (' | '.join(
- # converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
- # converter.visit(tool.function.parameters, tool.function.name + '-args')
- # for tool in tools
- # )) +
- # ") " +
- # ")") # + converter._format_literal(suffix))
-
- @typechecked
- def parse(s: str) -> Optional[Message]:
- s = strip_suffix(s)
-
- parts = _recipient_content_re.split(s)
- if len(parts) == 1:
- return Message(role="assistant", content=s)
- else:
- text_content = []
- tool_calls: list[ToolCall] = []
- for i in range((len(parts) - 1) // 3):
- assert parts[i * 3].strip() == '', f'Unexpected content before tool call: {parts[i * 3]}'
- recipient = parts[i * 3 + 1].strip()
- content = parts[i * 3 + 2]
- if recipient == 'all':
- text_content.append(content)
- else:
- try:
- arguments = json.loads(content)
- except json.JSONDecodeError:
- raise ValueError(f'Failed to parse tool call content as JSON: {content}')
- tool_calls.append(
- ToolCall(
- id=gen_callid(),
- function=FunctionCall(name=recipient, arguments=arguments)))
-
-
- assert parts[-1].strip() in ('', '<|stop|>'), f'Unexpected content after tool calls: {parts[-1]}\nFull string: {s}'
-
- content = '\n'.join(text_content).strip()
- return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
-
- return (converter.format_grammar(), parse)
-
- else:
- raise ValueError(f"Unsupported tool call style: {chat_format.tool_style}")
-
- elif response_schema:
- converter._add_rule("root", response_rule + ' ' + converter._format_literal(suffix))
-
- @typechecked
- def parse(s: str) -> Optional[Message]:
- s = strip_suffix(s)
- return Message(role="assistant", content=s)
-
- return (converter.format_grammar(), parse)
-
- else:
- converter._add_rule("root", converter._format_literal(prefix) + ' ' + converter._format_literal(suffix))
-
- @typechecked
- def parse(s: str) -> Optional[Message]:
- s = strip_suffix(s)
- return Message(role="assistant", content=s)
-
- return (None, parse)
-
diff --git a/examples/openai/server.py b/examples/openai/server.py
index ad3910625..fbd2f22da 100644
--- a/examples/openai/server.py
+++ b/examples/openai/server.py
@@ -5,12 +5,14 @@ import json, sys, subprocess, atexit
from pathlib import Path
import time
+from pydantic import TypeAdapter
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
from examples.openai.gguf_kvs import GGUFKeyValues, Keys
from examples.openai.api import ChatCompletionResponse, Choice, Message, ChatCompletionRequest, Usage
-from examples.openai.prompting import ChatFormat, make_grammar, make_tools_prompt
+from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, get_chat_handler, ChatHandler
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
@@ -38,8 +40,8 @@ def main(
metadata = GGUFKeyValues(model)
context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
- chat_format = ChatFormat.from_gguf(metadata)
- # print(chat_format)
+ chat_template = ChatTemplate.from_gguf(metadata)
+ # print(chat_template)
if not cpp_server_endpoint:
sys.stderr.write(f"# Starting C++ server with model {model} on {cpp_server_host}:{cpp_server_port}\n")
@@ -69,18 +71,17 @@ def main(
else:
response_schema = None
- messages = chat_request.messages
- if chat_request.tools:
- messages = chat_format.add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools))
-
- (grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
-
- # TODO: Test whether the template supports formatting tool_calls
-
- prompt = chat_format.render(messages, add_generation_prompt=True)
+ chat_handler = get_chat_handler(ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools))
+ messages = chat_request.messages
+ if chat_handler.output_format_prompt:
+ messages = chat_template.add_system_prompt(messages, chat_handler.output_format_prompt)
+
+ prompt = chat_template.render(messages, add_generation_prompt=True)
+
+ sys.stderr.write(f'\n# MESSAGES:\n\n{TypeAdapter(list[Message]).dump_json(messages)}\n\n')
sys.stderr.write(f'\n# PROMPT:\n\n{prompt}\n\n')
- sys.stderr.write(f'\n# GRAMMAR:\n\n{grammar}\n\n')
+ sys.stderr.write(f'\n# GRAMMAR:\n\n{chat_handler.grammar}\n\n')
data = LlamaCppServerCompletionRequest(
**{
@@ -94,9 +95,10 @@ def main(
)
},
prompt=prompt,
- grammar=grammar,
+ grammar=chat_handler.grammar,
).model_dump()
- sys.stderr.write(json.dumps(data, indent=2) + "\n")
+ # sys.stderr.write(json.dumps(data, indent=2) + "\n")
+
async with httpx.AsyncClient() as client:
response = await client.post(
f"{cpp_server_endpoint}/completions",
@@ -116,7 +118,7 @@ def main(
return JSONResponse(result)
# print(json.dumps(result.get('content'), indent=2))
- message = parser(result["content"])
+ message = chat_handler.parse(result["content"])
assert message is not None, f"Failed to parse response:\n{response.text}\n\n"
prompt_tokens=result['timings']['prompt_n']