server.py: refactor chat handlers

ochafik committed on 2024-03-29 02:47:33 +00:00
commit 59b411406f (parent 5f3de16116)
3 changed files with 487 additions and 321 deletions

examples/openai/api.py

@@ -18,7 +18,7 @@ class Message(BaseModel):
 class ToolFunction(BaseModel):
     name: str
     description: str
-    parameters: Any
+    parameters: dict[str, Any]

 class Tool(BaseModel):
     type: str
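The tightened `parameters: dict[str, Any]` matches the OpenAI-style tool format, where `parameters` carries a JSON Schema object. A minimal sketch of what such a payload looks like (the `get_weather` tool is a made-up example, not part of this commit):

    # Hypothetical tool definition; `parameters` is a JSON Schema dict,
    # which is why `Any` is narrowed to `dict[str, Any]`.
    tool = Tool(
        type="function",
        function=ToolFunction(
            name="get_weather",
            description="Get the current weather for a city",
            parameters={
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        ),
    )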

examples/openai/prompting.py

@@ -1,11 +1,14 @@
+from abc import ABC, abstractmethod
 from enum import Enum
+from functools import wraps
 import jinja2
 import json
 from pathlib import Path
 import random
 import re
 import sys
-from typing import Optional, Tuple, Callable
+from typing import Any, Dict, Literal, Optional, Tuple, Callable, Union
+from pydantic import BaseModel
 from typeguard import typechecked

 from examples.json_schema_to_grammar import SchemaConverter
@@ -18,22 +21,52 @@ def raise_exception(msg: str):
     raise Exception(msg)

 @typechecked
-class ChatFormat:
-    def __init__(self, template: str, eos_token: str, bos_token: str):
-        env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
-        self.template = env.from_string(template)
-        self.eos_token = eos_token
-        self.bos_token = bos_token
-
-        self.strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template
+class ChatTemplate(BaseModel):
+    template: str
+
+    @property
+    def tool_style(self) -> 'ToolsPromptStyle':
+        return self._tool_style
+
+    def __init__(self, template: str, eos_token: str, bos_token: str):
+        super().__init__(template=template
+        )
+        env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
+        self._template = env.from_string(template)
+        self._eos_token = eos_token
+        self._bos_token = bos_token
+
+        self._strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template

         if "<|recipient|>' + tool_call['function']['name']" in template:
-            self.tool_style = ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2
+            self._tool_style = ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2
         else:
-            self.tool_style = ToolsPromptStyle.TOOLS_LONG
+            self._tool_style = ToolsPromptStyle.TOOLS_BESPOKE
+            # self._tool_style = ToolsPromptStyle.TOOLS_LONG
+
+        # TODO: Test whether the template supports formatting tool_calls
+
+        delimiter = '<%$[SAMPLE]$%>'
+        user_msg = Message(role="user", content="Hey")
+        empty_prompt = self.render([user_msg], add_generation_prompt=True).strip()
+        planted_prompt = self.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
+        assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
+        [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
+        sys.stderr.write(f"\n# prefix={prefix}\n# suffix={suffix}\n\n")
+        self._prefix = prefix
+        self._suffix = suffix
+
+    def strip_suffix(self, s: str) -> str:
+        if s.endswith(self._suffix):
+            return s[:-len(self._suffix)]
+        else:
+            sys.stderr.write(f"Expected suffix ({self._suffix}) not found: {s}\n")
+            return s

     def __str__(self):
-        return f"ChatFormat(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})"
+        return f"ChatTemplate(template={self.template}, eos_token={self._eos_token}, bos_token={self._bos_token})"

     def add_system_prompt(self, messages: list[Message], system_prompt: Message) -> list[Message]:
         assert system_prompt.role == "system"
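The planted-delimiter trick above can be shown standalone: render the same conversation with and without a sentinel assistant message, and whatever surrounds the sentinel is the template's assistant-turn prefix and suffix. A rough sketch under the assumption of a ChatML-style template (the template string below is illustrative, not taken from this commit):

    import jinja2

    template = (
        "{% for m in messages %}<|im_start|>{{ m['role'] }}\n{{ m['content'] }}<|im_end|>\n{% endfor %}"
        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
    )
    env = jinja2.Environment(loader=jinja2.BaseLoader())
    render = lambda msgs, gen: env.from_string(template).render(messages=msgs, add_generation_prompt=gen)

    delimiter = '<%$[SAMPLE]$%>'
    user = {"role": "user", "content": "Hey"}
    empty = render([user], True).strip()
    planted = render([user, {"role": "assistant", "content": delimiter}], False).strip()
    prefix, suffix = planted[len(empty):].split(delimiter)
    # For this toy template: prefix == '\n' and suffix == '<|im_end|>'.
    # strip_suffix() later removes that suffix from raw completions before parsing.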
@@ -48,13 +81,13 @@ class ChatFormat:
     @staticmethod
     def from_gguf(metadata: GGUFKeyValues):
         tokens = metadata[Keys.Tokenizer.LIST]
-        return ChatFormat(
+        return ChatTemplate(
             template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
             bos_token = tokens[metadata[Keys.Tokenizer.BOS_ID]],
             eos_token = tokens[metadata[Keys.Tokenizer.EOS_ID]])

     def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
-        if self.strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
+        if self._strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
             new_messages=[]
             i = 0
             n = len(messages)
@@ -80,10 +113,10 @@ class ChatFormat:
             messages = new_messages
         # print(f'messages={messages}')

-        result = self.template.render(
+        result = self._template.render(
             messages=messages,
-            eos_token=self.eos_token,
-            bos_token='' if omit_bos else self.bos_token,
+            eos_token=self._eos_token,
+            bos_token='' if omit_bos else self._bos_token,
             raise_exception=raise_exception,
             add_generation_prompt=add_generation_prompt,
         )
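Taken together, from_gguf and render give the basic flow from OpenAI-style messages to a model-specific prompt string. A rough usage sketch (the model path is a placeholder, not from this commit):

    metadata = GGUFKeyValues(Path("model.gguf"))   # placeholder path
    chat_template = ChatTemplate.from_gguf(metadata)
    prompt = chat_template.render(
        [Message(role="user", content="What is the weather in Paris?")],
        add_generation_prompt=True,
    )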
@@ -95,139 +128,71 @@ class ChatFormat:
 # each model may need specific prompting (and/or constrained output,
 # especially for models not fine-tuned for tool usage / function calling).
 class ToolsPromptStyle(Enum):
-    # Short prompt w/ <tools>schemas</tools>
+    # Short prompt w/ <tools>schemas</tools>, <tool_call>...</tool_call> output
     TOOLS_SHORT = 1

-    # Longer prompt w/ <tools>schemas</tools>
+    # Longer prompt w/ <tools>schemas</tools>, <tool_call>...</tool_call> output
     TOOLS_LONG = 2

+    # Bespoke constrained output format that favours thought and reasoning
+    # while allowing unambiguous parsing of parallel tool calling.
+    TOOLS_BESPOKE = 3
+
     # Large prompt for https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B
+    # <tool_call>...</tool_call> output
     # Requires:
     # - git clone https://github.com/NousResearch/Hermes-Function-Calling examples/openai/hermes_function_calling
     # - Set large context length as their prompts are super long
-    TOOLS_HERMES_2_PRO = 3
+    TOOLS_HERMES_2_PRO = 4
+
+    # Seems to want to escape underscores in tool names and in the <tool\_call>...</tool\_call> tags
+    TOOLS_MISTRAL = 5

     # Short prompt w/ TypeScript definitions for https://github.com/MeetKai/functionary
     # https://github.com/MeetKai/functionary/blob/main/functionary/prompt_template/prompt_template_v2.py
     # Note: see this prior attempt to support Functionary: https://github.com/ggerganov/llama.cpp/pull/5695
-    TYPESCRIPT_FUNCTIONARY_V2 = 4
+    TYPESCRIPT_FUNCTIONARY_V2 = 6

-@typechecked
-def make_tools_prompt(chat_format: ChatFormat, tools: list[Tool], indent=2) -> Message:
-    if chat_format.tool_style == ToolsPromptStyle.TOOLS_SHORT:
-        return Message(
-            role="system",
-            content='\n'.join([
-                'Here are the tools available:',
-                '<tools>',
-                *(json.dumps(tool.model_dump(), indent=indent) for tool in tools),
-                '</tools>',
-            ])
-        )
-
-    elif chat_format.tool_style == ToolsPromptStyle.TOOLS_LONG:
-        return Message(
-            role="system",
-            content='\n'.join([
-                # '''You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.''',
-                '''You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:''',
-                '''<tools>''',
-                _tools_typescript_signatures(tools),
-                # _tools_schema_signatures(tools, indent=indent),
-                '''</tools>''',
-                '',
-                # '''Use the following json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}''',
-                # '',
-                # '''For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:''',
-                '''To call each function, give its name and arguments within <tool_call></tool_call> XML tags as follows:''',
-                '''<tool_call>''',
-                '''{"name": <function-name>, "arguments": <args-dict>}''',
-                '''</tool_call>''',
-                # '''This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.''',
-            ])
-        )
-
-    elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
-        return Message(
-            role="system",
-            content= '// Supported function definitions that should be called when necessary.\n' +
-                _tools_typescript_signatures(tools)
-        )
-
-    elif chat_format.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
-        # Hackily import https://github.com/NousResearch/Hermes-Function-Calling
-        path = str(Path(__file__).parent / "hermes_function_calling")
-        if path not in sys.path: sys.path.insert(0, path)
-        try:
-            from examples.openai.hermes_function_calling.prompter import PromptManager
-        except ImportError:
-            raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`")
-
-        prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in tools])
-        assert len(prompt) == 1 and prompt[0]["role"] == "system"
-        return Message(**prompt[0])
-
-    else:
-        raise ValueError(f"Unsupported tool call style: {chat_format.tool_style}")
-
-def _tools_typescript_signatures(tools: list[Tool]) -> str:
-    ts_converter = SchemaToTypeScriptConverter()
-    return 'namespace functions {' + '\n'.join(
-        '// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
-        'type ' + tool.function.name + ' = (_: ' + ts_converter.visit(tool.function.parameters) + ") => any;\n"
-        for tool in tools
-    ) + '} // namespace functions'
-
-def _tools_schema_signatures(tools: list[Tool], indent=None) -> str:
-    return '\n'.join(
-        json.dumps(tool.model_dump(), indent=indent)
-        for tool in tools
-    )
-
-@typechecked
-def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
-    return style in (
-        ToolsPromptStyle.TOOLS_SHORT,
-        ToolsPromptStyle.TOOLS_LONG,
-        ToolsPromptStyle.TOOLS_HERMES_2_PRO,
-    )
-
-_tool_call_re = re.compile(
-    '<tool_call>(.*?)</tool_call>', re.DOTALL)
-_recipient_content_re = re.compile(r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)
-
-def gen_callid():
-    return f'call_{random.randint(0, 1000000)}'
-
-@typechecked
-def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[list[Message]]]]:
-    converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
-
-    response_rule = converter.visit(response_schema, "response") if response_schema else None
-
-    delimiter = '<%$[SAMPLE]$%>'
-    user_msg = Message(role="user", content="Hey")
-    empty_prompt = chat_format.render([user_msg], add_generation_prompt=True).strip()
-    planted_prompt = chat_format.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
-    assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
-    [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
-
-    allow_parallel_calls = False
-
-    def strip_suffix(s: str) -> str:
-        if s.endswith(suffix):
-            return s[:-len(suffix)]
-        else:
-            sys.stderr.write(f"Expected suffix ({suffix}) not found: {s}\n")
-            return s
-
-    if tools:
-        if _outputs_tool_call_tags(chat_format.tool_style):
-            escapes_underscores = chat_format.tool_style != ToolsPromptStyle.TOOLS_HERMES_2_PRO
-
+class ChatHandlerArgs(BaseModel):
+    chat_template: ChatTemplate
+    response_schema: Optional[dict] = None
+    tools: Optional[list[Tool]] = None
+
+class ChatHandler(ABC):
+    def __init__(self, args: ChatHandlerArgs):
+        self.args = args
+        self.output_format_prompt: Optional[Message] = None
+        self.grammar: Optional[str] = None
+
+    @abstractmethod
+    def parse(self, s: str) -> Optional[Message]:
+        raise NotImplementedError()
+
+class NoToolsChatHandler(ChatHandler):
+    def __init__(self, args: ChatHandlerArgs):
+        super().__init__(args)
+        assert not args.tools
+
+        if args.response_schema:
+            self.output_format_prompt = Message(
+                role="system",
+                content=_please_respond_with_schema(args.response_schema)
+            )
+            converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
+            self.grammar = converter.visit(args.response_schema, '')
+        else:
+            self.output_format_prompt = None
+            self.grammar = None
+
+    @typechecked
+    def parse(self, s: str) -> Optional[Message]:
+        return Message(role="assistant", content=s)
+
+class ToolCallTagsChatHandler(ChatHandler):
+    def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, allow_parallel_calls: bool):
+        super().__init__(args)
+
+        converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
         tool_rules = [
             converter.visit(
                 dict(
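For the tool-call-tags styles, the model's raw completion embeds JSON between <tool_call> tags, and parsing boils down to regex extraction plus json.loads. A rough sketch of that idea, simplified relative to the handler's real parse() (the weather call is a made-up example):

    import json, re

    _tool_call_re = re.compile(r'<tool_call>(.*?)</tool_call>', re.DOTALL)

    raw = 'Let me check.\n<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
    content = _tool_call_re.sub('', raw).strip()                  # plain-text part of the reply
    calls = [json.loads(m) for m in _tool_call_re.findall(raw)]   # [{'name': 'get_weather', ...}]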
@@ -241,7 +206,7 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
                 ),
                 f'{tool.function.name}-tool-call'
             )
-            for tool in tools
+            for tool in self.args.tools
         ]

         def format_literal(s: str) -> str:
@@ -265,6 +230,7 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
             #     tool_call_rule)
             f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if allow_parallel_calls \
                 else f'{content_rule}* {tool_call_rule}?')
+        self.grammar = converter.format_grammar()

         # # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
         # # OR a tool-call message respecting the schema of any of the tools
@@ -278,8 +244,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
         #     ")") # + converter._format_literal(suffix))

     @typechecked
-    def parse(s: str) -> Optional[Message]:
-        s = strip_suffix(s)
+    def parse(self, s: str) -> Optional[Message]:
+        s = self.args.chat_template.strip_suffix(s)

         if r'<tool\_call>' in s:
             # Some weird escaping of underscores is happening w/ Mixtral 8x7B Instruct
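The escaped-tag case refers to models that emit <tool\_call> (and tool names like get\_weather) with backslash-escaped underscores; before regex extraction the handler has to normalize them back. A rough sketch of that normalization idea, not the exact code of this commit:

    # Undo Mixtral-style underscore escaping so '<tool\_call>' and
    # '{"name": "get\_weather", ...}' become parseable again.
    s = s.replace(r'<tool\_call>', '<tool_call>') \
         .replace(r'</tool\_call>', '</tool_call>') \
         .replace(r'\_', '_')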
@@ -315,12 +281,49 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
         # else:
         #     return Message(role="assistant", content=s)

-            return (converter.format_grammar(), parse)
-
-        elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
-            # Only allowing a single tool call at a time for now.
-            # Note that if there were more, they'd be separated by a '<|from|>assistant' literal
-
+class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
+    def __init__(self, args: ChatHandlerArgs, template: str, escapes_underscores=False, allow_parallel_calls=True):
+        super().__init__(args, escapes_underscores=escapes_underscores, allow_parallel_calls=allow_parallel_calls)
+
+        assert '{tools}' in template, 'Template must contain "{tools}"'
+        self.output_format_prompt = Message(
+            role="system",
+            content=template.replace(
+                '{tools}',
+                '\n'.join(json.dumps(tool.model_dump(), indent=2) for tool in self.args.tools),
+            )
+        )
+
+class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
+    def __init__(self, args: ChatHandlerArgs):
+        super().__init__(args, escapes_underscores=False, allow_parallel_calls=False)
+
+        # Hackily import https://github.com/NousResearch/Hermes-Function-Calling
+        path = str(Path(__file__).parent / "hermes_function_calling")
+        if path not in sys.path: sys.path.insert(0, path)
+        try:
+            from examples.openai.hermes_function_calling.prompter import PromptManager
+        except ImportError:
+            raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`")
+
+        prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in tools])
+        assert len(prompt) == 1 and prompt[0]["role"] == "system"
+        self.output_format_prompt = Message(**prompt[0])
+
+class FunctionaryToolsChatHandler(ChatHandler):
+    def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
+        super().__init__(args)
+
+        # Only allowing a single tool call at a time for now.
+        # Note that if there were more, they'd be separated by a '<|from|>assistant' literal
+        self.output_format_prompt = Message(
+            role="system",
+            content= '// Supported function definitions that should be called when necessary.\n' +
+                _tools_typescript_signatures(args.tools)
+        )
+
+        converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
         tool_rules = [
             converter._add_rule(
                 tool.function.name + '-call',
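Functionary-v2 style output routes each turn through <|recipient|> and <|content|> markers rather than <tool_call> tags, which is what _recipient_content_re picks apart; in Functionary's convention a recipient of "all" denotes plain assistant text, while any other recipient names the function being called. A rough sketch of the shape being parsed (the tool name and arguments are made up):

    import re
    # Same pattern as _recipient_content_re later in this file, repeated here so the sketch runs standalone.
    pattern = re.compile(
        r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>'
        r'(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)

    raw = 'get_weather\n<|content|>{"city": "Paris"}<|stop|>'
    parts = pattern.split(raw)   # ['', 'get_weather', '{"city": "Paris"}', '']
    # The capture pairs alternate (recipient, content); recipients other than 'all' become tool calls.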
@@ -338,7 +341,7 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
             #     ),
             #     f'{tool.function.name}-tool-call'
             # )
-            for i, tool in enumerate(tools)
+            for i, tool in enumerate(self.args.tools)
         ]

         not_from_rule = converter._add_rule('not_from', converter.not_literal("<|from|>"))
@@ -360,6 +363,7 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
             f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if allow_parallel_calls \
                 else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
+        self.grammar = converter.format_grammar()

         # converter._add_rule(
         #     "root",
         #     converter._format_literal(prefix) + " (" +
@@ -373,8 +377,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
         #     ")") # + converter._format_literal(suffix))

     @typechecked
-    def parse(s: str) -> Optional[Message]:
-        s = strip_suffix(s)
+    def parse(self, s: str) -> Optional[Message]:
+        s = self.args.chat_template.strip_suffix(s)

         parts = _recipient_content_re.split(s)
         if len(parts) == 1:
@@ -404,28 +408,188 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
             content = '\n'.join(text_content).strip()
             return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)

-            return (converter.format_grammar(), parse)
-
-        else:
-            raise ValueError(f"Unsupported tool call style: {chat_format.tool_style}")
-
-    elif response_schema:
-        converter._add_rule("root", response_rule + ' ' + converter._format_literal(suffix))
-
-        @typechecked
-        def parse(s: str) -> Optional[Message]:
-            s = strip_suffix(s)
-            return Message(role="assistant", content=s)
-
-        return (converter.format_grammar(), parse)
-
-    else:
-        converter._add_rule("root", converter._format_literal(prefix) + ' ' + converter._format_literal(suffix))
-
-        @typechecked
-        def parse(s: str) -> Optional[Message]:
-            s = strip_suffix(s)
-            return Message(role="assistant", content=s)
-
-        return (None, parse)
+def _make_bespoke_schema(response_schema, tool_call_schema):
+    return {
+        "type": "object",
+        "properties": {
+            # "original_goal": {"title": "Original Goal", "type": "string"},
+            "thought": {
+                # "title": "Thought about how the next step brings us closer to achieving the original goal",
+                "type": "string"
+            },
+            "next_step": {
+                "title": "Next Step: either a result or one or more tool calls to achieve the original goal",
+                "oneOf": [
+                    {
+                        "title": "Tool Calls",
+                        "properties": {
+                            # "type": {
+                            #     "const": "tool_calls"
+                            # },
+                            "tool_calls": {
+                                "type": "array",
+                                "items": tool_call_schema
+                            }
+                        },
+                        "required": ["tool_calls"]
+                    },
+                    {
+                        "title": "Result (achieving original goal)",
+                        "properties": {
+                            "result": response_schema,
+                        },
+                        "required": ["result"]
+                    },
+                ]
+            },
+        },
+        "required": ["original_goal", "thought", "next_step"]
+    }
+
+class BespokeToolsChatHandler(ChatHandler):
+    def __init__(self, args: ChatHandlerArgs):
+        super().__init__(args)
+
+        # args.response_schema = args.response_schema or {}
+        converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
+
+        response_schema = args.response_schema or {"type": "string"}
+        converter.visit(
+            _make_bespoke_schema(
+                response_schema,
+                {
+                    "oneOf": [
+                        {
+                            "type": "object",
+                            "properties": {
+                                "name": {"const": tool.function.name},
+                                "arguments": tool.function.parameters,
+                            },
+                            "required": ["name", "arguments"]
+                        }
+                        for tool in self.args.tools
+                    ]
+                }
+            ),
+            '',
+        )
+        self.grammar = converter.format_grammar()
+
+        self.output_format_prompt = Message(
+            role="system",
+            content='\n'.join([
+                'You are a function calling AI model.',
+                'Here are the tools available:',
+                _tools_schema_signatures(self.args.tools, indent=2),
+                _please_respond_with_schema(
+                    _make_bespoke_schema(
+                        response_schema,
+                        {
+                            "properties": {
+                                "name": {
+                                    "title": "Name of the tool to call",
+                                    "type": "string"
+                                },
+                                "arguments": {
+                                    "title": "Arguments to pass to the tool",
+                                    "type": "object"
+                                }
+                            },
+                            "required": ["name", "arguments"]
+                        }
+                    )
+                ),
+            ])
+        )
+
+    @typechecked
+    def parse(self, s: str) -> Optional[Message]:
+        s = self.args.chat_template.strip_suffix(s)
+        try:
+            data = json.loads(s)
+        except json.JSONDecodeError:
+            raise ValueError(f'Failed to parse data as JSON: {s}')
+
+        next_step = data['next_step']
+        if 'result' in next_step:
+            return Message(role="assistant", content=json.dumps(next_step['result']))
+        elif 'tool_calls' in next_step:
+            return Message(
+                role="assistant",
+                content=data["thought"],
+                tool_calls=[
+                    ToolCall(id=gen_callid(), function=FunctionCall(**tc))
+                    for tc in next_step['tool_calls']
+                ]
+            )
+        else:
+            raise ValueError(f'Unexpected data: {data}')
+
+_SHORT_TEMPLATE='\n'.join([
+    'Here are the tools available:',
+    '<tools>',
+    '{tools}',
+    '</tools>',
+])
+
+_LONG_TEMPLATE='\n'.join([
+    # 'You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.',
+    'You may call one or more functions to assist with the user query. Don\'t make assumptions about what values to plug into functions. Here are the available tools:',
+    '<tools>',
+    '{tools}',
+    '</tools>',
+    '',
+    # 'Use the following json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}',
+    # '',
+    # 'For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:',
+    'To call each function, give its name and arguments within <tool_call></tool_call> XML tags as follows:',
+    '<tool_call>',
+    '{"name": <function-name>, "arguments": <args-dict>}',
+    '</tool_call>',
+    # 'This is not hypothetical, you\'re not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.',
+])
+
+def get_chat_handler(args: ChatHandlerArgs, allow_parallel_calls=False) -> ChatHandler:
+    if not args.tools:
+        return NoToolsChatHandler(args)
+    elif args.chat_template.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
+        return FunctionaryToolsChatHandler(args)
+    elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_SHORT:
+        return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
+    elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_LONG:
+        return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
+    elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL:
+        return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, escapes_underscores=True, allow_parallel_calls=allow_parallel_calls)
+    elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE:
+        return BespokeToolsChatHandler(args)
+    elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
+        return Hermes2ProToolsChatHandler(args)
+    else:
+        raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}")
+
+_ts_converter = SchemaToTypeScriptConverter()
+
+def _please_respond_with_schema(schema: dict) -> str:
+    # sig = json.dumps(schema, indent=2)
+    sig = _ts_converter.visit(schema)
+    return f'Please respond in JSON format with the following schema: {sig}'
+
+def _tools_typescript_signatures(tools: list[Tool]) -> str:
+    return 'namespace functions {' + '\n'.join(
+        '// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
+        'type ' + tool.function.name + ' = (_: ' + _ts_converter.visit(tool.function.parameters) + ") => any;\n"
+        for tool in tools
+    ) + '} // namespace functions'
+
+def _tools_schema_signatures(tools: list[Tool], indent=None) -> str:
+    return '\n'.join(
+        json.dumps(tool.model_dump(), indent=indent)
+        for tool in tools
+    )
+
+_tool_call_re = re.compile(
+    '<tool_call>(.*?)</tool_call>', re.DOTALL)
+_recipient_content_re = re.compile(r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)

+def gen_callid():
+    return f'call_{random.randint(0, 1000000)}'
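Concretely, the TOOLS_BESPOKE grammar forces the model to emit a single JSON object with a "thought" plus a "next_step" that holds either a "result" or a list of "tool_calls", which is exactly what BespokeToolsChatHandler.parse() unpacks. A sketch of one such constrained completion (the tool name and arguments are illustrative only):

    {
      "thought": "The user wants the current weather, so I should call the weather tool.",
      "next_step": {
        "tool_calls": [
          {"name": "get_weather", "arguments": {"city": "Paris"}}
        ]
      }
    }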

examples/openai/server.py

@@ -5,12 +5,14 @@ import json, sys, subprocess, atexit
 from pathlib import Path
 import time

+from pydantic import TypeAdapter
+
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
 from examples.openai.gguf_kvs import GGUFKeyValues, Keys
 from examples.openai.api import ChatCompletionResponse, Choice, Message, ChatCompletionRequest, Usage
-from examples.openai.prompting import ChatFormat, make_grammar, make_tools_prompt
+from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, get_chat_handler, ChatHandler
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse
@@ -38,8 +40,8 @@ def main(
     metadata = GGUFKeyValues(model)
     context_length = metadata[Keys.LLM.CONTEXT_LENGTH]

-    chat_format = ChatFormat.from_gguf(metadata)
-    # print(chat_format)
+    chat_template = ChatTemplate.from_gguf(metadata)
+    # print(chat_template)

     if not cpp_server_endpoint:
         sys.stderr.write(f"# Starting C++ server with model {model} on {cpp_server_host}:{cpp_server_port}\n")
@@ -69,18 +71,17 @@ def main(
         else:
             response_schema = None

+        chat_handler = get_chat_handler(ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools))
+
         messages = chat_request.messages
-        if chat_request.tools:
-            messages = chat_format.add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools))
+        if chat_handler.output_format_prompt:
+            messages = chat_template.add_system_prompt(messages, chat_handler.output_format_prompt)

-        (grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
+        prompt = chat_template.render(messages, add_generation_prompt=True)

-        # TODO: Test whether the template supports formatting tool_calls
-        prompt = chat_format.render(messages, add_generation_prompt=True)
-
+        sys.stderr.write(f'\n# MESSAGES:\n\n{TypeAdapter(list[Message]).dump_json(messages)}\n\n')
         sys.stderr.write(f'\n# PROMPT:\n\n{prompt}\n\n')
-        sys.stderr.write(f'\n# GRAMMAR:\n\n{grammar}\n\n')
+        sys.stderr.write(f'\n# GRAMMAR:\n\n{chat_handler.grammar}\n\n')

         data = LlamaCppServerCompletionRequest(
             **{
@@ -94,9 +95,10 @@ def main(
                 )
             },
             prompt=prompt,
-            grammar=grammar,
+            grammar=chat_handler.grammar,
         ).model_dump()
-        sys.stderr.write(json.dumps(data, indent=2) + "\n")
+        # sys.stderr.write(json.dumps(data, indent=2) + "\n")

         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{cpp_server_endpoint}/completions",
@@ -116,7 +118,7 @@ def main(
                 return JSONResponse(result)

             # print(json.dumps(result.get('content'), indent=2))
-            message = parser(result["content"])
+            message = chat_handler.parse(result["content"])
             assert message is not None, f"Failed to parse response:\n{response.text}\n\n"

             prompt_tokens=result['timings']['prompt_n']