diff --git a/examples/openai/api.py b/examples/openai/api.py
index dd8da09a2..7c4a446b8 100644
--- a/examples/openai/api.py
+++ b/examples/openai/api.py
@@ -18,7 +18,7 @@ class Message(BaseModel):
 class ToolFunction(BaseModel):
     name: str
     description: str
-    parameters: Any
+    parameters: dict[str, Any]
 
 class Tool(BaseModel):
     type: str
diff --git a/examples/openai/prompting.py b/examples/openai/prompting.py
index e26ca9229..60dab69ba 100644
--- a/examples/openai/prompting.py
+++ b/examples/openai/prompting.py
@@ -1,11 +1,14 @@
+from abc import ABC, abstractmethod
 from enum import Enum
+from functools import wraps
 import jinja2
 import json
 from pathlib import Path
 import random
 import re
 import sys
-from typing import Optional, Tuple, Callable
+from typing import Any, Dict, Literal, Optional, Tuple, Callable, Union
+from pydantic import BaseModel
 from typeguard import typechecked
 
 from examples.json_schema_to_grammar import SchemaConverter
@@ -18,22 +21,52 @@ def raise_exception(msg: str):
     raise Exception(msg)
 
 @typechecked
-class ChatFormat:
-    def __init__(self, template: str, eos_token: str, bos_token: str):
-        env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
-        self.template = env.from_string(template)
-        self.eos_token = eos_token
-        self.bos_token = bos_token
+class ChatTemplate(BaseModel):
+    template: str
 
-        self.strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template
+    @property
+    def tool_style(self) -> 'ToolsPromptStyle':
+        return self._tool_style
+
+    def __init__(self, template: str, eos_token: str, bos_token: str):
+        super().__init__(template=template)
+        env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
+        self._template = env.from_string(template)
+        self._eos_token = eos_token
+        self._bos_token = bos_token
+
+        self._strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template
 
         if "<|recipient|>' + tool_call['function']['name']" in template:
-            self.tool_style = ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2
+            self._tool_style = ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2
         else:
-            self.tool_style = ToolsPromptStyle.TOOLS_LONG
+            self._tool_style = ToolsPromptStyle.TOOLS_BESPOKE
+            # self._tool_style = ToolsPromptStyle.TOOLS_LONG
+
+        # TODO: Test whether the template supports formatting tool_calls
+
+        delimiter = '<%$[SAMPLE]$%>'
+        user_msg = Message(role="user", content="Hey")
+        empty_prompt = self.render([user_msg], add_generation_prompt=True).strip()
+        planted_prompt = self.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
+        assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
+        [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
+
+        sys.stderr.write(f"\n# prefix={prefix}\n# suffix={suffix}\n\n")
+
+        self._prefix = prefix
+        self._suffix = suffix
+
+    def strip_suffix(self, s: str) -> str:
+        if s.endswith(self._suffix):
+            return s[:-len(self._suffix)]
+        else:
+            sys.stderr.write(f"Expected suffix ({self._suffix}) not found: {s}\n")
+            return s
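# --- Illustration (not part of the patch): the delimiter-planting trick above, shown
# --- standalone. Assumes a ChatML-style template; names and values are illustrative.
import jinja2

CHATML = (
    "{% for message in messages %}<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n"
    "{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}")
env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
template = env.from_string(CHATML)
delimiter = '<%$[SAMPLE]$%>'
user = [{"role": "user", "content": "Hey"}]
empty = template.render(messages=user, add_generation_prompt=True).strip()
planted = template.render(messages=user + [{"role": "assistant", "content": delimiter}],
                          add_generation_prompt=False).strip()
assert planted.startswith(empty)
prefix, suffix = planted[len(empty):].split(delimiter)
# prefix == '\n', suffix == '<|im_end|>': whatever the template wraps around an assistant
# turn, later removed from raw completions by strip_suffix().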
 
     def __str__(self):
-        return f"ChatFormat(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})"
+        return f"ChatTemplate(template={self.template}, eos_token={self._eos_token}, bos_token={self._bos_token})"
 
     def add_system_prompt(self, messages: list[Message], system_prompt: Message) -> list[Message]:
         assert system_prompt.role == "system"
@@ -48,13 +81,13 @@ class ChatFormat:
     @staticmethod
     def from_gguf(metadata: GGUFKeyValues):
         tokens = metadata[Keys.Tokenizer.LIST]
-        return ChatFormat(
+        return ChatTemplate(
             template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
             bos_token = tokens[metadata[Keys.Tokenizer.BOS_ID]],
             eos_token = tokens[metadata[Keys.Tokenizer.EOS_ID]])
 
     def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
-        if self.strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
+        if self._strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
             new_messages=[]
             i = 0
             n = len(messages)
@@ -80,10 +113,10 @@ class ChatFormat:
             messages = new_messages
             # print(f'messages={messages}')
 
-        result = self.template.render(
+        result = self._template.render(
             messages=messages,
-            eos_token=self.eos_token,
-            bos_token='' if omit_bos else self.bos_token,
+            eos_token=self._eos_token,
+            bos_token='' if omit_bos else self._bos_token,
             raise_exception=raise_exception,
             add_generation_prompt=add_generation_prompt,
         )
@@ -95,67 +128,176 @@ class ChatFormat:
 # each model may need specific prompting (and/or constrained output,
 # especially for models not fine-tuned for tool usage / function calling).
 class ToolsPromptStyle(Enum):
-    # Short prompt w/ schemas
+    # Short prompt w/ schemas, <tool_call>...</tool_call> output
     TOOLS_SHORT = 1
 
-    # Longer prompt w/ schemas
+    # Longer prompt w/ schemas, <tool_call>...</tool_call> output
     TOOLS_LONG = 2
 
+    # Bespoke constrained output format that favours thought and reasoning
+    # while allowing unambiguous parsing of parallel tool calling.
+    TOOLS_BESPOKE = 3
+
     # Large prompt for https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B
+    # <tool_call>...</tool_call> output
     # Requires:
     # - git clone https://github.com/NousResearch/Hermes-Function-Calling examples/openai/hermes_function_calling
     # - Set large context length as their prompts are super long
-    TOOLS_HERMES_2_PRO = 3
+    TOOLS_HERMES_2_PRO = 4
+
+    # Seems to want to escape underscores in tool names and in the <tool\_call> tags
+    TOOLS_MISTRAL = 5
 
     # Short prompt w/ TypeScript definitions for https://github.com/MeetKai/functionary
     # https://github.com/MeetKai/functionary/blob/main/functionary/prompt_template/prompt_template_v2.py
     # Note: see this prior attempt to support Functionary: https://github.com/ggerganov/llama.cpp/pull/5695
-    TYPESCRIPT_FUNCTIONARY_V2 = 4
+    TYPESCRIPT_FUNCTIONARY_V2 = 6
 
-@typechecked
-def make_tools_prompt(chat_format: ChatFormat, tools: list[Tool], indent=2) -> Message:
+class ChatHandlerArgs(BaseModel):
+    chat_template: ChatTemplate
+    response_schema: Optional[dict] = None
+    tools: Optional[list[Tool]] = None
 
-    if chat_format.tool_style == ToolsPromptStyle.TOOLS_SHORT:
-        return Message(
+class ChatHandler(ABC):
+    def __init__(self, args: ChatHandlerArgs):
+        self.args = args
+        self.output_format_prompt: Optional[Message] = None
+        self.grammar: Optional[str] = None
+
+    @abstractmethod
+    def parse(self, s: str) -> Optional[Message]:
+        raise NotImplementedError()
+
+class NoToolsChatHandler(ChatHandler):
+    def __init__(self, args: ChatHandlerArgs):
+        super().__init__(args)
+        assert not args.tools
+
+        if args.response_schema:
+            self.output_format_prompt = Message(
+                role="system",
+                content=_please_respond_with_schema(args.response_schema)
+            )
+            converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
+            converter.visit(args.response_schema, '')
+            self.grammar = converter.format_grammar()
+        else:
+            self.output_format_prompt = None
+            self.grammar = None
+
+    @typechecked
+    def parse(self, s: str) -> Optional[Message]:
+        return Message(role="assistant", content=s)
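# --- Illustration (not part of the patch): converting a response JSON schema into a
# --- GBNF grammar with the SchemaConverter used above (from
# --- examples/json_schema_to_grammar.py). The schema here is illustrative.
from examples.json_schema_to_grammar import SchemaConverter

converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
converter.visit({
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}, '')  # an empty name makes this the root rule
grammar = converter.format_grammar()  # GBNF text, sent as the `grammar` field to the C++ server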
+
+class ToolCallTagsChatHandler(ChatHandler):
+    def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, allow_parallel_calls: bool):
+        super().__init__(args)
+
+        converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
+        tool_rules = [
+            converter.visit(
+                dict(
+                    type="object",
+                    properties=dict(
+                        name=dict(type="string", pattern='^' + tool.function.name.replace('_', f'\\?_') + '$') if escapes_underscores \
+                            else dict(const=tool.function.name),
+                        arguments=tool.function.parameters,
+                    ),
+                    required=['name', 'arguments']
+                ),
+                f'{tool.function.name}-tool-call'
+            )
+            for tool in self.args.tools
+        ]
+
+        def format_literal(s: str) -> str:
+            if escapes_underscores:
+                return ' "\\\\"? "_" '.join((converter._format_literal(part) for part in s.split('_')))
+            else:
+                return converter._format_literal(s)
+
+        tool_call_rule = converter._add_rule(
+            'tool_call',
+            format_literal("<tool_call>") + " space (" +
+            ' | '.join(tool_rules) +
+            ") space " + format_literal("</tool_call>"))# + ' space')
+
+        # Ideally we'd want a negative lookahead of /<tool\?_call>/, but it's just too hard to express in GBNF for now.
+        # So we just over-constrain the content rule to not contain literals dangerously getting close to <tool_call>
+        content_rule = converter._add_rule('content', '[^<] | "<" [^t<] | "<t" [^o<]')
+        converter._add_rule(
+            'root',
+            # tool_call_rule)
+            f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if allow_parallel_calls \
+                else f'{content_rule}* {tool_call_rule}?')
+        self.grammar = converter.format_grammar()
+
+        # # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
+        # # OR a tool-call message respecting the schema of any of the tools
+        # converter._add_rule(
+        #     "root",
+        #     converter._format_literal(prefix) + " (" +
+        #         (response_rule or converter.not_literal("<tool_call>")) + " | " +
+        #         converter._format_literal("<tool_call>") + " (" +
+        #         ' | '.join(tool_rules) +
+        #         ") " + converter._format_literal("</tool_call>") +
+        #     ")") # + converter._format_literal(suffix))
+
+    @typechecked
+    def parse(self, s: str) -> Optional[Message]:
+        s = self.args.chat_template.strip_suffix(s)
+
+        if r'<tool\_call>' in s:
+            # Some weird escaping of underscores is happening w/ Mixtral 8x7B Instruct
+            s = s.replace(r'\_', '_')
+
+        parts = _tool_call_re.split(s)
+        if len(parts) == 1:
+            return Message(role="assistant", content=s)
+        else:
+            content = []
+            tool_calls = []
+            for i, part in enumerate(parts):
+                if i % 2 == 0:
+                    content.append(part)
+                else:
+                    try:
+                        fc = json.loads(part)
+                    except json.JSONDecodeError:
+                        raise ValueError(f'Failed to parse tool call as JSON: {part}\nFull string: {s}')
+                    tool_calls.append(
+                        ToolCall(
+                            id=gen_callid(),
+                            function=FunctionCall(**fc)))
+
+            content = '\n'.join(content).strip()
+            return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
+
+        # if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
+        #     if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
+        #         tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
+        #         return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
+        #     return None
+        # else:
+        #     return Message(role="assistant", content=s)
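# --- Illustration (not part of the patch): how parse() above splits tagged output.
# --- _tool_call_re has one capturing group, so re.split alternates plain content
# --- (even indices) with captured tool-call JSON (odd indices). Sample text is made up.
import re, json

tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)
s = 'Let me check.<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
parts = tool_call_re.split(s)
# parts == ['Let me check.', '{"name": "get_weather", "arguments": {"city": "Paris"}}', '']
fc = json.loads(parts[1])
assert fc['name'] == 'get_weather'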
+
+class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
+    def __init__(self, args: ChatHandlerArgs, template: str, escapes_underscores=False, allow_parallel_calls=True):
+        super().__init__(args, escapes_underscores=escapes_underscores, allow_parallel_calls=allow_parallel_calls)
+        assert '{tools}' in template, 'Template must contain "{tools}"'
+
+        self.output_format_prompt = Message(
             role="system",
-            content='\n'.join([
-                'Here are the tools available:',
-                '<tools>',
-                *(json.dumps(tool.model_dump(), indent=indent) for tool in tools),
-                '</tools>',
-            ])
+            content=template.replace(
+                '{tools}',
+                '\n'.join(json.dumps(tool.model_dump(), indent=2) for tool in self.args.tools),
+            )
         )
-
-    elif chat_format.tool_style == ToolsPromptStyle.TOOLS_LONG:
-        return Message(
-            role="system",
-            content='\n'.join([
-                # '''You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.''',
-                '''You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:''',
-                '''<tools>''',
-                _tools_typescript_signatures(tools),
-                # _tools_schema_signatures(tools, indent=indent),
-                '''</tools>''',
-                '',
-                # '''Use the following json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}''',
-                # '',
-                # '''For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:''',
-                '''To call each function, give its name and arguments within <tool_call></tool_call> XML tags as follows:''',
-                '''<tool_call>''',
-                '''{"name": <function-name>, "arguments": <args-dict>}''',
-                '''</tool_call>''',
-                # '''This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.''',
-            ])
-        )
-
-    elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
-        return Message(
-            role="system",
-            content= '// Supported function definitions that should be called when necessary.\n' +
-                _tools_typescript_signatures(tools)
-        )
-
-    elif chat_format.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
+
+class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
+    def __init__(self, args: ChatHandlerArgs):
+        super().__init__(args, escapes_underscores=False, allow_parallel_calls=False)
+
         # Hackily import https://github.com/NousResearch/Hermes-Function-Calling
         path = str(Path(__file__).parent / "hermes_function_calling")
         if path not in sys.path: sys.path.insert(0, path)
@@ -166,16 +308,276 @@ def make_tools_prompt(chat_format: ChatFormat, tools: list[Tool], indent=2) -> Message:
 
         prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in tools])
         assert len(prompt) == 1 and prompt[0]["role"] == "system"
-        return Message(**prompt[0])
+        self.output_format_prompt = Message(**prompt[0])
+
+class FunctionaryToolsChatHandler(ChatHandler):
+    def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
+        super().__init__(args)
+
+        # Only allowing a single tool call at a time for now.
+        # Note that if there were more, they'd be separated by a '<|from|>assistant' literal
+
+        self.output_format_prompt = Message(
+            role="system",
+            content= '// Supported function definitions that should be called when necessary.\n' +
+                _tools_typescript_signatures(args.tools)
+        )
+        converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
+        tool_rules = [
+            converter._add_rule(
+                tool.function.name + '-call',
+                converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
+                converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
+                converter._format_literal('\n'))
+            # converter.visit(
+            #     dict(
+            #         type="object",
+            #         properties=dict(
+            #             name=dict(const=tool.function.name),
+            #             arguments=tool.function.parameters,
+            #         ),
+            #         required=['name', 'arguments']
+            #     ),
+            #     f'{tool.function.name}-tool-call'
+            # )
+            for i, tool in enumerate(self.args.tools)
+        ]
+
+        not_from_rule = converter._add_rule('not_from', converter.not_literal("<|from|>"))
+        content_without_start_rule = converter._add_rule(
+            'content_without_start',
+            converter._format_literal("all\n<|content|>") + ' ' + not_from_rule + '*')
+        start_rule = converter._add_rule('start', converter._format_literal('<|from|>assistant\n<|recipient|>'))
+        content_rule = converter._add_rule('content', start_rule + ' ' + content_without_start_rule)
+        tool_call_without_start_rule = converter._add_rule(
+            'tool_call_without_start',
+            ' | '.join(tool_rules))
+            # + ' ' +
+            # converter.not_literal("all", dotall=False) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + not_from_rule + '*')
+        tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}')
+        # converter._add_rule('root', f'({content_without_start_rule} ({content_rule})* ({tool_call_rule}+ {content_rule}*)? | {tool_call_without_start_rule} (* {tool_call_rule}{content_rule}*')
+        converter._add_rule(
+            'root',
+            f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
+            f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if allow_parallel_calls \
+                else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
+
+        self.grammar = converter.format_grammar()
+        # converter._add_rule(
+        #     "root",
+        #     converter._format_literal(prefix) + " (" +
+        #         (response_rule or converter.not_literal("<|recipient|>")) + " | " +
+        #         (' | '.join(
+        #             converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
+        #             converter.visit(tool.function.parameters, tool.function.name + '-args')
+        #             for tool in tools
+        #         )) +
+        #         ") " +
+        #     ")") # + converter._format_literal(suffix))
+
+    @typechecked
+    def parse(self, s: str) -> Optional[Message]:
+        s = self.args.chat_template.strip_suffix(s)
+
+        parts = _recipient_content_re.split(s)
+        if len(parts) == 1:
+            return Message(role="assistant", content=s)
+        else:
+            text_content = []
+            tool_calls: list[ToolCall] = []
+            for i in range((len(parts) - 1) // 3):
+                assert parts[i * 3].strip() == '', f'Unexpected content before tool call: {parts[i * 3]}'
+                recipient = parts[i * 3 + 1].strip()
+                content = parts[i * 3 + 2]
+                if recipient == 'all':
+                    text_content.append(content)
+                else:
+                    try:
+                        arguments = json.loads(content)
+                    except json.JSONDecodeError:
+                        raise ValueError(f'Failed to parse tool call content as JSON: {content}')
+                    tool_calls.append(
+                        ToolCall(
+                            id=gen_callid(),
+                            function=FunctionCall(name=recipient, arguments=arguments)))
+
+            assert parts[-1].strip() in ('', '<|stop|>'), f'Unexpected content after tool calls: {parts[-1]}\nFull string: {s}'
+
+            content = '\n'.join(text_content).strip()
+            return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
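# --- Illustration (not part of the patch): the Functionary v2 wire format that the
# --- grammar above constrains and parse() above consumes. Sample text is made up;
# --- the <|from|>/<|recipient|>/<|content|> literals come from the grammar rules.
sample = (
    'all\n<|content|>Let me look that up.\n'
    '<|from|>assistant\n<|recipient|>get_weather\n<|content|>{"city": "Paris"}\n<|stop|>'
)
# _recipient_content_re.split(sample) yields (separator, recipient, content) triples:
#   ['', 'all', 'Let me look that up.\n', '', 'get_weather', '{"city": "Paris"}\n', '']
# recipient 'all' is plain assistant text; any other recipient is a tool call whose
# <|content|> payload is parsed as JSON arguments.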
+
+def _make_bespoke_schema(response_schema, tool_call_schema):
+    return {
+        "type": "object",
+        "properties": {
+            # "original_goal": {"title": "Original Goal", "type": "string"},
+            "thought": {
+                # "title": "Thought about how the next step brings us closer to achieving the original goal",
+                "type": "string"
+            },
+            "next_step": {
+                "title": "Next Step: either a result or one or more tool calls to achieve the original goal",
+                "oneOf": [
+                    {
+                        "title": "Tool Calls",
+                        "properties": {
+                            # "type": {
+                            #     "const": "tool_calls"
+                            # },
+                            "tool_calls": {
+                                "type": "array",
+                                "items": tool_call_schema
+                            }
+                        },
+                        "required": ["tool_calls"]
+                    },
+                    {
+                        "title": "Result (achieving original goal)",
+                        "properties": {
+                            "result": response_schema,
+                        },
+                        "required": ["result"]
+                    },
+                ]
+            },
+        },
+        "required": ["thought", "next_step"]
+    }
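# --- Illustration (not part of the patch): completions that satisfy
# --- _make_bespoke_schema above (tool name and arguments are made up).
example = {
    "thought": "The user wants the weather, so I should call the weather tool.",
    "next_step": {"tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}}]}
}
# ...or, once the goal is achieved, a terminal result instead of tool_calls:
final = {"thought": "I have the answer.", "next_step": {"result": "It is 18°C in Paris."}}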
"Arguments to pass to the tool", + "type": "object" + } + }, + "required": ["name", "arguments"] + } + ) + ), + ]) + ) + + @typechecked + def parse(self, s: str) -> Optional[Message]: + s = self.args.chat_template.strip_suffix(s) + try: + data = json.loads(s) + except json.JSONDecodeError: + raise ValueError(f'Failed to parse data as JSON: {s}') + + next_step = data['next_step'] + if 'result' in next_step: + return Message(role="assistant", content=json.dumps(next_step['result'])) + elif 'tool_calls' in next_step: + return Message( + role="assistant", + content=data["thought"], + tool_calls=[ + ToolCall(id=gen_callid(), function=FunctionCall(**tc)) + for tc in next_step['tool_calls'] + ] + ) + else: + raise ValueError(f'Unexpected data: {data}') + +_SHORT_TEMPLATE='\n'.join([ + 'Here are the tools available:', + '', + '{tools}', + '', +]) + +_LONG_TEMPLATE='\n'.join([ + # '''You are a function calling AI model. You are provided with function signatures within XML tags.''', + 'You may call one or more functions to assist with the user query. Don\'t make assumptions about what values to plug into functions. Here are the available tools:', + '', + '{tools}', + '', + '', + # 'Use the following json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}', + # '', + # 'For each function call return a json object with function name and arguments within XML tags as follows:', + 'To call each function, give its name and arguments within XML tags as follows:', + '', + '{"name": , "arguments": }', + '', + # 'This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with ....''', +]) + +def get_chat_handler(args: ChatHandlerArgs, allow_parallel_calls=False) -> ChatHandler: + if not args.tools: + return NoToolsChatHandler(args) + elif args.chat_template.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2: + return FunctionaryToolsChatHandler(args) + elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_SHORT: + return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, allow_parallel_calls=allow_parallel_calls) + elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_LONG: + return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, allow_parallel_calls=allow_parallel_calls) + elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL: + return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, escapes_underscores=True, allow_parallel_calls=allow_parallel_calls) + elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE: + return BespokeToolsChatHandler(args) + elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO: + return Hermes2ProToolsChatHandler(args) else: - raise ValueError(f"Unsupported tool call style: {chat_format.tool_style}") - + raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}") + +_ts_converter = SchemaToTypeScriptConverter() + +def _please_respond_with_schema(schema: dict) -> str: + # sig = json.dumps(schema, indent=2) + sig = _ts_converter.visit(schema) + return f'Please respond in JSON format with the following schema: {sig}' + def _tools_typescript_signatures(tools: list[Tool]) -> str: - ts_converter = SchemaToTypeScriptConverter() return 'namespace functions {' + '\n'.join( '// ' + tool.function.description.replace('\n', '\n// ') + '\n' + '' - 'type ' + tool.function.name + ' = (_: ' 
@@ -185,247 +587,9 @@ def _tools_schema_signatures(tools: list[Tool], indent=None) -> str:
         for tool in tools
     )
 
-@typechecked
-def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
-    return style in (
-        ToolsPromptStyle.TOOLS_SHORT,
-        ToolsPromptStyle.TOOLS_LONG,
-        ToolsPromptStyle.TOOLS_HERMES_2_PRO,
-    )
-
 _tool_call_re = re.compile(
     '<tool_call>(.*?)</tool_call>', re.DOTALL)
 _recipient_content_re = re.compile(r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)
 
 def gen_callid():
     return f'call_{random.randint(0, 1000000)}'
-
-@typechecked
-def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[list[Message]]]]:
-
-    converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
-
-    response_rule = converter.visit(response_schema, "response") if response_schema else None
-
-    delimiter = '<%$[SAMPLE]$%>'
-    user_msg = Message(role="user", content="Hey")
-    empty_prompt = chat_format.render([user_msg], add_generation_prompt=True).strip()
-    planted_prompt = chat_format.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
-    assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
-    [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
-
-    allow_parallel_calls = False
-
-    def strip_suffix(s: str) -> str:
-        if s.endswith(suffix):
-            return s[:-len(suffix)]
-        else:
-            sys.stderr.write(f"Expected suffix ({suffix}) not found: {s}\n")
-            return s
-
-    if tools:
-        if _outputs_tool_call_tags(chat_format.tool_style):
-
-            escapes_underscores = chat_format.tool_style != ToolsPromptStyle.TOOLS_HERMES_2_PRO
-
-            tool_rules = [
-                converter.visit(
-                    dict(
-                        type="object",
-                        properties=dict(
-                            name=dict(type="string", pattern='^' + tool.function.name.replace('_', f'\\?_') + '$') if escapes_underscores \
-                                else dict(const=tool.function.name),
-                            arguments=tool.function.parameters,
-                        ),
-                        required=['name', 'arguments']
-                    ),
-                    f'{tool.function.name}-tool-call'
-                )
-                for tool in tools
-            ]
-
-            def format_literal(s: str) -> str:
-                if escapes_underscores:
-                    return ' "\\\\"? "_" '.join((converter._format_literal(part) for part in s.split('_')))
-                else:
-                    return converter._format_literal(s)
-
-            tool_call_rule = converter._add_rule(
-                'tool_call',
-                format_literal("<tool_call>") + " space (" +
-                ' | '.join(tool_rules) +
-                ") space " + format_literal("</tool_call>"))# + ' space')
-
-            # Ideally we'd want a negative lookahead of /<tool\?_call>/, but it's just too hard to express in GBNF for now.
-            # So we just over-constrain the content rule to not contain literals dangerously getting close to <tool_call>
-            content_rule = converter._add_rule('content', '[^<] | "<" [^t<] | "<t" [^o<]')
-            converter._add_rule(
-                'root',
-                # tool_call_rule)
-                f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if allow_parallel_calls \
-                    else f'{content_rule}* {tool_call_rule}?')
-
-            # # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
-            # # OR a tool-call message respecting the schema of any of the tools
-            # converter._add_rule(
-            #     "root",
-            #     converter._format_literal(prefix) + " (" +
-            #         (response_rule or converter.not_literal("<tool_call>")) + " | " +
-            #         converter._format_literal("<tool_call>") + " (" +
-            #         ' | '.join(tool_rules) +
-            #         ") " + converter._format_literal("</tool_call>") +
-            #     ")") # + converter._format_literal(suffix))
-
-            @typechecked
-            def parse(s: str) -> Optional[Message]:
-                s = strip_suffix(s)
-
-                if r'<tool\_call>' in s:
-                    # Some weird escaping of underscores is happening w/ Mixtral 8x7B Instruct
-                    s = s.replace(r'\_', '_')
-
-                parts = _tool_call_re.split(s)
-                if len(parts) == 1:
-                    return Message(role="assistant", content=s)
-                else:
-                    content = []
-                    tool_calls = []
-                    for i, part in enumerate(parts):
-                        if i % 2 == 0:
-                            content.append(part)
-                        else:
-                            try:
-                                fc = json.loads(part)
-                            except json.JSONDecodeError:
-                                raise ValueError(f'Failed to parse tool call as JSON: {part}\nFull string: {s}')
-                            tool_calls.append(
-                                ToolCall(
-                                    id=gen_callid(),
-                                    function=FunctionCall(**fc)))
-
-                    content = '\n'.join(content).strip()
-                    return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
-
-                # if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
-                #     if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
-                #         tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
-                #         return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
-                #     return None
-                # else:
-                #     return Message(role="assistant", content=s)
-
-            return (converter.format_grammar(), parse)
-
-        elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
-            # Only allowing a single tool call at a time for now.
-            # Note that if there were more, they'd be separated by a '<|from|>assistant' literal
-
-            tool_rules = [
-                converter._add_rule(
-                    tool.function.name + '-call',
-                    converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
-                    converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
-                    converter._format_literal('\n'))
-                # converter.visit(
-                #     dict(
-                #         type="object",
-                #         properties=dict(
-                #             name=dict(const=tool.function.name),
-                #             arguments=tool.function.parameters,
-                #         ),
-                #         required=['name', 'arguments']
-                #     ),
-                #     f'{tool.function.name}-tool-call'
-                # )
-                for i, tool in enumerate(tools)
-            ]
-
-            not_from_rule = converter._add_rule('not_from', converter.not_literal("<|from|>"))
-            content_without_start_rule = converter._add_rule(
-                'content_without_start',
-                converter._format_literal("all\n<|content|>") + ' ' + not_from_rule + '*')
-            start_rule = converter._add_rule('start', converter._format_literal('<|from|>assistant\n<|recipient|>'))
-            content_rule = converter._add_rule('content', start_rule + ' ' + content_without_start_rule)
-            tool_call_without_start_rule = converter._add_rule(
-                'tool_call_without_start',
-                ' | '.join(tool_rules))
-                # + ' ' +
-                # converter.not_literal("all", dotall=False) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + not_from_rule + '*')
-            tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}')
-            # converter._add_rule('root', f'({content_without_start_rule} ({content_rule})* ({tool_call_rule}+ {content_rule}*)? | {tool_call_without_start_rule} (* {tool_call_rule}{content_rule}*')
-            converter._add_rule(
-                'root',
-                f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
-                f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if allow_parallel_calls \
-                    else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
-
-            # converter._add_rule(
-            #     "root",
-            #     converter._format_literal(prefix) + " (" +
-            #         (response_rule or converter.not_literal("<|recipient|>")) + " | " +
-            #         (' | '.join(
-            #             converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
-            #             converter.visit(tool.function.parameters, tool.function.name + '-args')
-            #             for tool in tools
-            #         )) +
-            #         ") " +
-            #     ")") # + converter._format_literal(suffix))
-
-            @typechecked
-            def parse(s: str) -> Optional[Message]:
-                s = strip_suffix(s)
-
-                parts = _recipient_content_re.split(s)
-                if len(parts) == 1:
-                    return Message(role="assistant", content=s)
-                else:
-                    text_content = []
-                    tool_calls: list[ToolCall] = []
-                    for i in range((len(parts) - 1) // 3):
-                        assert parts[i * 3].strip() == '', f'Unexpected content before tool call: {parts[i * 3]}'
-                        recipient = parts[i * 3 + 1].strip()
-                        content = parts[i * 3 + 2]
-                        if recipient == 'all':
-                            text_content.append(content)
-                        else:
-                            try:
-                                arguments = json.loads(content)
-                            except json.JSONDecodeError:
-                                raise ValueError(f'Failed to parse tool call content as JSON: {content}')
-                            tool_calls.append(
-                                ToolCall(
-                                    id=gen_callid(),
-                                    function=FunctionCall(name=recipient, arguments=arguments)))
-
-                    assert parts[-1].strip() in ('', '<|stop|>'), f'Unexpected content after tool calls: {parts[-1]}\nFull string: {s}'
-
-                    content = '\n'.join(text_content).strip()
-                    return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
-
-            return (converter.format_grammar(), parse)
-
-        else:
-            raise ValueError(f"Unsupported tool call style: {chat_format.tool_style}")
-
-    elif response_schema:
-        converter._add_rule("root", response_rule + ' ' + converter._format_literal(suffix))
-
-        @typechecked
-        def parse(s: str) -> Optional[Message]:
-            s = strip_suffix(s)
-            return Message(role="assistant", content=s)
-
-        return (converter.format_grammar(), parse)
-
-    else:
-        converter._add_rule("root", converter._format_literal(prefix) + ' ' + converter._format_literal(suffix))
-
-        @typechecked
-        def parse(s: str) -> Optional[Message]:
-            s = strip_suffix(s)
-            return Message(role="assistant", content=s)
-
-        return (None, parse)
-
diff --git a/examples/openai/server.py b/examples/openai/server.py
index ad3910625..fbd2f22da 100644
--- a/examples/openai/server.py
+++ b/examples/openai/server.py
@@ -5,12 +5,14 @@ import json, sys, subprocess, atexit
 from pathlib import Path
 import time
 
+from pydantic import TypeAdapter
+
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
 from examples.openai.gguf_kvs import GGUFKeyValues, Keys
 from examples.openai.api import ChatCompletionResponse, Choice, Message, ChatCompletionRequest, Usage
-from examples.openai.prompting import ChatFormat, make_grammar, make_tools_prompt
+from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, get_chat_handler, ChatHandler
 
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse
@@ -38,8 +40,8 @@ def main(
     metadata = GGUFKeyValues(model)
     context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
 
-    chat_format = ChatFormat.from_gguf(metadata)
-    # print(chat_format)
+    chat_template = ChatTemplate.from_gguf(metadata)
+    # print(chat_template)
 
     if not cpp_server_endpoint:
         sys.stderr.write(f"# Starting C++ server with model {model} on {cpp_server_host}:{cpp_server_port}\n")
@@ -69,18 +71,17 @@ def main(
         else:
             response_schema = None
 
-        messages = chat_request.messages
-        if chat_request.tools:
-            messages = chat_format.add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools))
-
-        (grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
-
-        # TODO: Test whether the template supports formatting tool_calls
-
-        prompt = chat_format.render(messages, add_generation_prompt=True)
+        chat_handler = get_chat_handler(ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools))
+
+        messages = chat_request.messages
+        if chat_handler.output_format_prompt:
+            messages = chat_template.add_system_prompt(messages, chat_handler.output_format_prompt)
+
+        prompt = chat_template.render(messages, add_generation_prompt=True)
 
+        sys.stderr.write(f'\n# MESSAGES:\n\n{TypeAdapter(list[Message]).dump_json(messages)}\n\n')
         sys.stderr.write(f'\n# PROMPT:\n\n{prompt}\n\n')
-        sys.stderr.write(f'\n# GRAMMAR:\n\n{grammar}\n\n')
+        sys.stderr.write(f'\n# GRAMMAR:\n\n{chat_handler.grammar}\n\n')
 
         data = LlamaCppServerCompletionRequest(
             **{
@@ -94,9 +95,10 @@ def main(
                 )
             },
             prompt=prompt,
-            grammar=grammar,
+            grammar=chat_handler.grammar,
         ).model_dump()
-        sys.stderr.write(json.dumps(data, indent=2) + "\n")
+        # sys.stderr.write(json.dumps(data, indent=2) + "\n")
+
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{cpp_server_endpoint}/completions",
@@ -116,7 +118,7 @@ def main(
                 return JSONResponse(result)
 
             # print(json.dumps(result.get('content'), indent=2))
-            message = parser(result["content"])
+            message = chat_handler.parse(result["content"])
             assert message is not None, f"Failed to parse response:\n{response.text}\n\n"
 
             prompt_tokens=result['timings']['prompt_n']
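# --- Illustration (not part of the patch): end-to-end flow of the handler API as wired
# --- up in server.py above. The model path and tool definition are illustrative.
from examples.openai.api import Message, Tool, ToolFunction
from examples.openai.gguf_kvs import GGUFKeyValues
from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, get_chat_handler

chat_template = ChatTemplate.from_gguf(GGUFKeyValues("model.gguf"))
tools = [Tool(type="function", function=ToolFunction(
    name="get_weather",
    description="Get the current weather for a city",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}))]
handler = get_chat_handler(ChatHandlerArgs(chat_template=chat_template, response_schema=None, tools=tools))

messages = [Message(role="user", content="What's the weather in Paris?")]
if handler.output_format_prompt:
    messages = chat_template.add_system_prompt(messages, handler.output_format_prompt)
prompt = chat_template.render(messages, add_generation_prompt=True)
# `prompt` and `handler.grammar` are sent to the llama.cpp server's /completions
# endpoint; the returned text then goes through handler.parse(...) to recover
# content and tool_calls.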