server.py: refactor chat handlers
This commit is contained in:
parent
5f3de16116
commit
59b411406f
3 changed files with 487 additions and 321 deletions
|
@ -18,7 +18,7 @@ class Message(BaseModel):
|
||||||
class ToolFunction(BaseModel):
|
class ToolFunction(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
parameters: Any
|
parameters: dict[str, Any]
|
||||||
|
|
||||||
class Tool(BaseModel):
|
class Tool(BaseModel):
|
||||||
type: str
|
type: str
|
||||||
|
|
|
@ -1,11 +1,14 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from functools import wraps
|
||||||
import jinja2
|
import jinja2
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from typing import Optional, Tuple, Callable
|
from typing import Any, Dict, Literal, Optional, Tuple, Callable, Union
|
||||||
|
from pydantic import BaseModel
|
||||||
from typeguard import typechecked
|
from typeguard import typechecked
|
||||||
|
|
||||||
from examples.json_schema_to_grammar import SchemaConverter
|
from examples.json_schema_to_grammar import SchemaConverter
|
||||||
|
@ -18,22 +21,52 @@ def raise_exception(msg: str):
|
||||||
raise Exception(msg)
|
raise Exception(msg)
|
||||||
|
|
||||||
@typechecked
|
@typechecked
|
||||||
class ChatFormat:
|
class ChatTemplate(BaseModel):
|
||||||
def __init__(self, template: str, eos_token: str, bos_token: str):
|
template: str
|
||||||
env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
|
|
||||||
self.template = env.from_string(template)
|
|
||||||
self.eos_token = eos_token
|
|
||||||
self.bos_token = bos_token
|
|
||||||
|
|
||||||
self.strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template
|
@property
|
||||||
|
def tool_style(self) -> 'ToolsPromptStyle':
|
||||||
|
return self._tool_style
|
||||||
|
|
||||||
|
def __init__(self, template: str, eos_token: str, bos_token: str):
|
||||||
|
super().__init__(template=template
|
||||||
|
)
|
||||||
|
env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
|
||||||
|
self._template = env.from_string(template)
|
||||||
|
self._eos_token = eos_token
|
||||||
|
self._bos_token = bos_token
|
||||||
|
|
||||||
|
self._strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template
|
||||||
|
|
||||||
if "<|recipient|>' + tool_call['function']['name']" in template:
|
if "<|recipient|>' + tool_call['function']['name']" in template:
|
||||||
self.tool_style = ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2
|
self._tool_style = ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2
|
||||||
else:
|
else:
|
||||||
self.tool_style = ToolsPromptStyle.TOOLS_LONG
|
self._tool_style = ToolsPromptStyle.TOOLS_BESPOKE
|
||||||
|
# self._tool_style = ToolsPromptStyle.TOOLS_LONG
|
||||||
|
|
||||||
|
# TODO: Test whether the template supports formatting tool_calls
|
||||||
|
|
||||||
|
delimiter = '<%$[SAMPLE]$%>'
|
||||||
|
user_msg = Message(role="user", content="Hey")
|
||||||
|
empty_prompt = self.render([user_msg], add_generation_prompt=True).strip()
|
||||||
|
planted_prompt = self.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
|
||||||
|
assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
|
||||||
|
[prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
|
||||||
|
|
||||||
|
sys.stderr.write(f"\n# prefix={prefix}\n# suffix={suffix}\n\n")
|
||||||
|
|
||||||
|
self._prefix = prefix
|
||||||
|
self._suffix = suffix
|
||||||
|
|
||||||
|
def strip_suffix(self, s: str) -> str:
|
||||||
|
if s.endswith(self._suffix):
|
||||||
|
return s[:-len(self._suffix)]
|
||||||
|
else:
|
||||||
|
sys.stderr.write(f"Expected suffix ({self._suffix}) not found: {s}\n")
|
||||||
|
return s
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"ChatFormat(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})"
|
return f"ChatTemplate(template={self.template}, eos_token={self._eos_token}, bos_token={self._bos_token})"
|
||||||
|
|
||||||
def add_system_prompt(self, messages: list[Message], system_prompt: Message) -> list[Message]:
|
def add_system_prompt(self, messages: list[Message], system_prompt: Message) -> list[Message]:
|
||||||
assert system_prompt.role == "system"
|
assert system_prompt.role == "system"
|
||||||
|
@ -48,13 +81,13 @@ class ChatFormat:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_gguf(metadata: GGUFKeyValues):
|
def from_gguf(metadata: GGUFKeyValues):
|
||||||
tokens = metadata[Keys.Tokenizer.LIST]
|
tokens = metadata[Keys.Tokenizer.LIST]
|
||||||
return ChatFormat(
|
return ChatTemplate(
|
||||||
template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
|
template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
|
||||||
bos_token = tokens[metadata[Keys.Tokenizer.BOS_ID]],
|
bos_token = tokens[metadata[Keys.Tokenizer.BOS_ID]],
|
||||||
eos_token = tokens[metadata[Keys.Tokenizer.EOS_ID]])
|
eos_token = tokens[metadata[Keys.Tokenizer.EOS_ID]])
|
||||||
|
|
||||||
def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
|
def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
|
||||||
if self.strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
|
if self._strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
|
||||||
new_messages=[]
|
new_messages=[]
|
||||||
i = 0
|
i = 0
|
||||||
n = len(messages)
|
n = len(messages)
|
||||||
|
@ -80,10 +113,10 @@ class ChatFormat:
|
||||||
messages = new_messages
|
messages = new_messages
|
||||||
# print(f'messages={messages}')
|
# print(f'messages={messages}')
|
||||||
|
|
||||||
result = self.template.render(
|
result = self._template.render(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
eos_token=self.eos_token,
|
eos_token=self._eos_token,
|
||||||
bos_token='' if omit_bos else self.bos_token,
|
bos_token='' if omit_bos else self._bos_token,
|
||||||
raise_exception=raise_exception,
|
raise_exception=raise_exception,
|
||||||
add_generation_prompt=add_generation_prompt,
|
add_generation_prompt=add_generation_prompt,
|
||||||
)
|
)
|
||||||
|
@ -95,67 +128,176 @@ class ChatFormat:
|
||||||
# each model may need specific prompting (and/or constrained output,
|
# each model may need specific prompting (and/or constrained output,
|
||||||
# especially for models not fine-tuned for tool usage / function calling).
|
# especially for models not fine-tuned for tool usage / function calling).
|
||||||
class ToolsPromptStyle(Enum):
|
class ToolsPromptStyle(Enum):
|
||||||
# Short prompt w/ <tools>schemas</tools>
|
# Short prompt w/ <tools>schemas</tools>, <tool_call>...</tool_call> output
|
||||||
TOOLS_SHORT = 1
|
TOOLS_SHORT = 1
|
||||||
|
|
||||||
# Longer prompt w/ <tools>schemas</tools>
|
# Longer prompt w/ <tools>schemas</tools>, <tool_call>...</tool_call> output
|
||||||
TOOLS_LONG = 2
|
TOOLS_LONG = 2
|
||||||
|
|
||||||
|
# Bespoke constrained output format that favours thought and reasoning
|
||||||
|
# while allowing unambiguous parsing of parallel tool calling.
|
||||||
|
TOOLS_BESPOKE = 3
|
||||||
|
|
||||||
# Large prompt for https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B
|
# Large prompt for https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B
|
||||||
|
# <tool_call>...</tool_call> output
|
||||||
# Requires:
|
# Requires:
|
||||||
# - git clone https://github.com/NousResearch/Hermes-Function-Calling examples/openai/hermes_function_calling
|
# - git clone https://github.com/NousResearch/Hermes-Function-Calling examples/openai/hermes_function_calling
|
||||||
# - Set large context length as their prompts are super long
|
# - Set large context length as their prompts are super long
|
||||||
TOOLS_HERMES_2_PRO = 3
|
TOOLS_HERMES_2_PRO = 4
|
||||||
|
|
||||||
|
# Seems to want to escape underscores in tool names and in the <tool\_call>...</tool\_call> tags
|
||||||
|
TOOLS_MISTRAL = 5
|
||||||
|
|
||||||
# Short prompt w/ TypeScript definitions for https://github.com/MeetKai/functionary
|
# Short prompt w/ TypeScript definitions for https://github.com/MeetKai/functionary
|
||||||
# https://github.com/MeetKai/functionary/blob/main/functionary/prompt_template/prompt_template_v2.py
|
# https://github.com/MeetKai/functionary/blob/main/functionary/prompt_template/prompt_template_v2.py
|
||||||
# Note: see this prior attempt to support Functionary: https://github.com/ggerganov/llama.cpp/pull/5695
|
# Note: see this prior attempt to support Functionary: https://github.com/ggerganov/llama.cpp/pull/5695
|
||||||
TYPESCRIPT_FUNCTIONARY_V2 = 4
|
TYPESCRIPT_FUNCTIONARY_V2 = 6
|
||||||
|
|
||||||
@typechecked
|
class ChatHandlerArgs(BaseModel):
|
||||||
def make_tools_prompt(chat_format: ChatFormat, tools: list[Tool], indent=2) -> Message:
|
chat_template: ChatTemplate
|
||||||
|
response_schema: Optional[dict] = None
|
||||||
|
tools: Optional[list[Tool]] = None
|
||||||
|
|
||||||
if chat_format.tool_style == ToolsPromptStyle.TOOLS_SHORT:
|
class ChatHandler(ABC):
|
||||||
return Message(
|
def __init__(self, args: ChatHandlerArgs):
|
||||||
|
self.args = args
|
||||||
|
self.output_format_prompt: Optional[Message] = None
|
||||||
|
self.grammar: Optional[str] = None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse(self, s: str) -> Optional[Message]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
class NoToolsChatHandler(ChatHandler):
|
||||||
|
def __init__(self, args: ChatHandlerArgs):
|
||||||
|
super().__init__(args)
|
||||||
|
assert not args.tools
|
||||||
|
|
||||||
|
if args.response_schema:
|
||||||
|
self.output_format_prompt = Message(
|
||||||
|
role="system",
|
||||||
|
content=_please_respond_with_schema(args.response_schema)
|
||||||
|
)
|
||||||
|
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
|
||||||
|
self.grammar = converter.visit(args.response_schema, '')
|
||||||
|
else:
|
||||||
|
self.output_format_prompt = None
|
||||||
|
self.grammar = None
|
||||||
|
|
||||||
|
@typechecked
|
||||||
|
def parse(self, s: str) -> Optional[Message]:
|
||||||
|
return Message(role="assistant", content=s)
|
||||||
|
|
||||||
|
class ToolCallTagsChatHandler(ChatHandler):
|
||||||
|
def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, allow_parallel_calls: bool):
|
||||||
|
super().__init__(args)
|
||||||
|
|
||||||
|
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
|
||||||
|
tool_rules = [
|
||||||
|
converter.visit(
|
||||||
|
dict(
|
||||||
|
type="object",
|
||||||
|
properties=dict(
|
||||||
|
name=dict(type="string", pattern='^' + tool.function.name.replace('_', f'\\?_') + '$') if escapes_underscores \
|
||||||
|
else dict(const=tool.function.name),
|
||||||
|
arguments=tool.function.parameters,
|
||||||
|
),
|
||||||
|
required=['name', 'arguments']
|
||||||
|
),
|
||||||
|
f'{tool.function.name}-tool-call'
|
||||||
|
)
|
||||||
|
for tool in self.args.tools
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_literal(s: str) -> str:
|
||||||
|
if escapes_underscores:
|
||||||
|
return ' "\\\\"? "_" '.join((converter._format_literal(part) for part in s.split('_')))
|
||||||
|
else:
|
||||||
|
return converter._format_literal(s)
|
||||||
|
|
||||||
|
tool_call_rule = converter._add_rule(
|
||||||
|
'tool_call',
|
||||||
|
format_literal("<tool_call>") + " space (" +
|
||||||
|
' | '.join(tool_rules) +
|
||||||
|
") space " + format_literal("</tool_call>"))# + ' space')
|
||||||
|
|
||||||
|
# Ideally we'd want a negative lookahead of /<tool\\?_call>/, but it's just too hard to express in GBNF for now.
|
||||||
|
# So we just over-constrain the content rule to not contain literals dangerously getting close to <tool_call>
|
||||||
|
content_rule = converter._add_rule('content', '[^<] | "<" [^t<] | "<t" [^o<]')
|
||||||
|
# content_rule = converter._add_rule('content', converter.not_literal('<tool_call>'))
|
||||||
|
converter._add_rule(
|
||||||
|
'root',
|
||||||
|
# tool_call_rule)
|
||||||
|
f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if allow_parallel_calls \
|
||||||
|
else f'{content_rule}* {tool_call_rule}?')
|
||||||
|
self.grammar = converter.format_grammar()
|
||||||
|
|
||||||
|
# # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
|
||||||
|
# # OR a tool-call message respecting the schema of any of the tools
|
||||||
|
# converter._add_rule(
|
||||||
|
# "root",
|
||||||
|
# converter._format_literal(prefix) + " (" +
|
||||||
|
# (response_rule or converter.not_literal("<tool_call>")) + " | " +
|
||||||
|
# converter._format_literal("<tool_call>") + " (" +
|
||||||
|
# ' | '.join(tool_rules) +
|
||||||
|
# ") " + converter._format_literal("</tool_call>") +
|
||||||
|
# ")") # + converter._format_literal(suffix))
|
||||||
|
|
||||||
|
@typechecked
|
||||||
|
def parse(self, s: str) -> Optional[Message]:
|
||||||
|
s = self.args.chat_template.strip_suffix(s)
|
||||||
|
|
||||||
|
if r'<tool\_call>' in s:
|
||||||
|
# Some weird escaping of underscores is happening w/ Mixtral 8x7B Instruct
|
||||||
|
s = s.replace(r'\_', '_')
|
||||||
|
|
||||||
|
parts = _tool_call_re.split(s)
|
||||||
|
if len(parts) == 1:
|
||||||
|
return Message(role="assistant", content=s)
|
||||||
|
else:
|
||||||
|
content = []
|
||||||
|
tool_calls = []
|
||||||
|
for i, part in enumerate(parts):
|
||||||
|
if i % 2 == 0:
|
||||||
|
content.append(part)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
fc = json.loads(part)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f'Failed to parse tool call as JSON: {part}\nFull string: {s}')
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
id=gen_callid(),
|
||||||
|
function=FunctionCall(**fc)))
|
||||||
|
|
||||||
|
content = '\n'.join(content).strip()
|
||||||
|
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
|
||||||
|
|
||||||
|
# if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
|
||||||
|
# if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
|
||||||
|
# tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
|
||||||
|
# return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
|
||||||
|
# return None
|
||||||
|
# else:
|
||||||
|
# return Message(role="assistant", content=s)
|
||||||
|
|
||||||
|
class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
|
||||||
|
def __init__(self, args: ChatHandlerArgs, template: str, escapes_underscores=False, allow_parallel_calls=True):
|
||||||
|
super().__init__(args, escapes_underscores=escapes_underscores, allow_parallel_calls=allow_parallel_calls)
|
||||||
|
assert '{tools}' in template, 'Template must contain "{tools}"'
|
||||||
|
|
||||||
|
self.output_format_prompt = Message(
|
||||||
role="system",
|
role="system",
|
||||||
content='\n'.join([
|
content=template.replace(
|
||||||
'Here are the tools available:',
|
'{tools}',
|
||||||
'<tools>',
|
'\n'.join(json.dumps(tool.model_dump(), indent=2) for tool in self.args.tools),
|
||||||
*(json.dumps(tool.model_dump(), indent=indent) for tool in tools),
|
)
|
||||||
'</tools>',
|
|
||||||
])
|
|
||||||
)
|
)
|
||||||
|
|
||||||
elif chat_format.tool_style == ToolsPromptStyle.TOOLS_LONG:
|
class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
|
||||||
return Message(
|
def __init__(self, args: ChatHandlerArgs):
|
||||||
role="system",
|
super().__init__(args, escapes_underscores=False, allow_parallel_calls=False)
|
||||||
content='\n'.join([
|
|
||||||
# '''You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.''',
|
|
||||||
'''You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:''',
|
|
||||||
'''<tools>''',
|
|
||||||
_tools_typescript_signatures(tools),
|
|
||||||
# _tools_schema_signatures(tools, indent=indent),
|
|
||||||
'''</tools>''',
|
|
||||||
'',
|
|
||||||
# '''Use the following json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}''',
|
|
||||||
# '',
|
|
||||||
# '''For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:''',
|
|
||||||
'''To call each function, give its name and arguments within <tool_call></tool_call> XML tags as follows:''',
|
|
||||||
'''<tool_call>''',
|
|
||||||
'''{"name": <function-name>, "arguments": <args-dict>}''',
|
|
||||||
'''</tool_call>''',
|
|
||||||
# '''This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.''',
|
|
||||||
])
|
|
||||||
)
|
|
||||||
|
|
||||||
elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
|
|
||||||
return Message(
|
|
||||||
role="system",
|
|
||||||
content= '// Supported function definitions that should be called when necessary.\n' +
|
|
||||||
_tools_typescript_signatures(tools)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif chat_format.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
|
|
||||||
# Hackily import https://github.com/NousResearch/Hermes-Function-Calling
|
# Hackily import https://github.com/NousResearch/Hermes-Function-Calling
|
||||||
path = str(Path(__file__).parent / "hermes_function_calling")
|
path = str(Path(__file__).parent / "hermes_function_calling")
|
||||||
if path not in sys.path: sys.path.insert(0, path)
|
if path not in sys.path: sys.path.insert(0, path)
|
||||||
|
@ -166,16 +308,276 @@ def make_tools_prompt(chat_format: ChatFormat, tools: list[Tool], indent=2) -> M
|
||||||
|
|
||||||
prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in tools])
|
prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in tools])
|
||||||
assert len(prompt) == 1 and prompt[0]["role"] == "system"
|
assert len(prompt) == 1 and prompt[0]["role"] == "system"
|
||||||
return Message(**prompt[0])
|
self.output_format_prompt = Message(**prompt[0])
|
||||||
|
|
||||||
|
class FunctionaryToolsChatHandler(ChatHandler):
|
||||||
|
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
|
||||||
|
super().__init__(args)
|
||||||
|
|
||||||
|
# Only allowing a single tool call at a time for now.
|
||||||
|
# Note that if there were more, they'd be separated by a '<|from|>assistant' literal
|
||||||
|
|
||||||
|
self.output_format_prompt = Message(
|
||||||
|
role="system",
|
||||||
|
content= '// Supported function definitions that should be called when necessary.\n' +
|
||||||
|
_tools_typescript_signatures(args.tools)
|
||||||
|
)
|
||||||
|
|
||||||
|
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
|
||||||
|
tool_rules = [
|
||||||
|
converter._add_rule(
|
||||||
|
tool.function.name + '-call',
|
||||||
|
converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
|
||||||
|
converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
|
||||||
|
converter._format_literal('\n'))
|
||||||
|
# converter.visit(
|
||||||
|
# dict(
|
||||||
|
# type="object",
|
||||||
|
# properties=dict(
|
||||||
|
# name=dict(const=tool.function.name),
|
||||||
|
# arguments=tool.function.parameters,
|
||||||
|
# ),
|
||||||
|
# required=['name', 'arguments']
|
||||||
|
# ),
|
||||||
|
# f'{tool.function.name}-tool-call'
|
||||||
|
# )
|
||||||
|
for i, tool in enumerate(self.args.tools)
|
||||||
|
]
|
||||||
|
|
||||||
|
not_from_rule = converter._add_rule('not_from', converter.not_literal("<|from|>"))
|
||||||
|
content_without_start_rule = converter._add_rule(
|
||||||
|
'content_without_start',
|
||||||
|
converter._format_literal("all\n<|content|>") + ' ' + not_from_rule + '*')
|
||||||
|
start_rule = converter._add_rule('start', converter._format_literal('<|from|>assistant\n<|recipient|>'))
|
||||||
|
content_rule = converter._add_rule('content', start_rule + ' ' + content_without_start_rule)
|
||||||
|
tool_call_without_start_rule = converter._add_rule(
|
||||||
|
'tool_call_without_start',
|
||||||
|
' | '.join(tool_rules))
|
||||||
|
# + ' ' +
|
||||||
|
# converter.not_literal("all", dotall=False) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + not_from_rule + '*')
|
||||||
|
tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}')
|
||||||
|
# converter._add_rule('root', f'({content_without_start_rule} ({content_rule})* ({tool_call_rule}+ {content_rule}*)? | {tool_call_without_start_rule} (* {tool_call_rule}{content_rule}*')
|
||||||
|
converter._add_rule(
|
||||||
|
'root',
|
||||||
|
f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
|
||||||
|
f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if allow_parallel_calls \
|
||||||
|
else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
|
||||||
|
|
||||||
|
self.grammar = converter.format_grammar()
|
||||||
|
# converter._add_rule(
|
||||||
|
# "root",
|
||||||
|
# converter._format_literal(prefix) + " (" +
|
||||||
|
# (response_rule or converter.not_literal("<|recipient|>")) + " | " +
|
||||||
|
# (' | '.join(
|
||||||
|
# converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
|
||||||
|
# converter.visit(tool.function.parameters, tool.function.name + '-args')
|
||||||
|
# for tool in tools
|
||||||
|
# )) +
|
||||||
|
# ") " +
|
||||||
|
# ")") # + converter._format_literal(suffix))
|
||||||
|
|
||||||
|
@typechecked
|
||||||
|
def parse(self, s: str) -> Optional[Message]:
|
||||||
|
s = self.args.chat_template.strip_suffix(s)
|
||||||
|
|
||||||
|
parts = _recipient_content_re.split(s)
|
||||||
|
if len(parts) == 1:
|
||||||
|
return Message(role="assistant", content=s)
|
||||||
|
else:
|
||||||
|
text_content = []
|
||||||
|
tool_calls: list[ToolCall] = []
|
||||||
|
for i in range((len(parts) - 1) // 3):
|
||||||
|
assert parts[i * 3].strip() == '', f'Unexpected content before tool call: {parts[i * 3]}'
|
||||||
|
recipient = parts[i * 3 + 1].strip()
|
||||||
|
content = parts[i * 3 + 2]
|
||||||
|
if recipient == 'all':
|
||||||
|
text_content.append(content)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
arguments = json.loads(content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f'Failed to parse tool call content as JSON: {content}')
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
id=gen_callid(),
|
||||||
|
function=FunctionCall(name=recipient, arguments=arguments)))
|
||||||
|
|
||||||
|
|
||||||
|
assert parts[-1].strip() in ('', '<|stop|>'), f'Unexpected content after tool calls: {parts[-1]}\nFull string: {s}'
|
||||||
|
|
||||||
|
content = '\n'.join(text_content).strip()
|
||||||
|
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
|
||||||
|
|
||||||
|
def _make_bespoke_schema(response_schema, tool_call_schema):
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
# "original_goal": {"title": "Original Goal", "type": "string"},
|
||||||
|
"thought": {
|
||||||
|
# "title": "Thought about how the next step brings us closer to achieving the original goal",
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"next_step": {
|
||||||
|
"title": "Next Step: either a result or one or more tool calls to achieve the original goal",
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"title": "Tool Calls",
|
||||||
|
"properties": {
|
||||||
|
# "type": {
|
||||||
|
# "const": "tool_calls"
|
||||||
|
# },
|
||||||
|
"tool_calls": {
|
||||||
|
"type": "array",
|
||||||
|
"items": tool_call_schema
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["tool_calls"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Result (achieving original goal)",
|
||||||
|
"properties": {
|
||||||
|
"result": response_schema,
|
||||||
|
},
|
||||||
|
"required": ["result"]
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["original_goal", "thought", "next_step"]
|
||||||
|
}
|
||||||
|
|
||||||
|
class BespokeToolsChatHandler(ChatHandler):
|
||||||
|
def __init__(self, args: ChatHandlerArgs):
|
||||||
|
super().__init__(args)
|
||||||
|
|
||||||
|
# args.response_schema = args.response_schema or {}
|
||||||
|
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
|
||||||
|
|
||||||
|
response_schema = args.response_schema or {"type": "string"}
|
||||||
|
converter.visit(
|
||||||
|
_make_bespoke_schema(
|
||||||
|
response_schema,
|
||||||
|
{
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"const": tool.function.name},
|
||||||
|
"arguments": tool.function.parameters,
|
||||||
|
},
|
||||||
|
"required": ["name", "arguments"]
|
||||||
|
}
|
||||||
|
for tool in self.args.tools
|
||||||
|
]
|
||||||
|
}
|
||||||
|
),
|
||||||
|
'',
|
||||||
|
)
|
||||||
|
self.grammar = converter.format_grammar()
|
||||||
|
|
||||||
|
self.output_format_prompt = Message(
|
||||||
|
role="system",
|
||||||
|
content='\n'.join([
|
||||||
|
'You are a function calling AI model.',
|
||||||
|
'Here are the tools available:',
|
||||||
|
_tools_schema_signatures(self.args.tools, indent=2),
|
||||||
|
_please_respond_with_schema(
|
||||||
|
_make_bespoke_schema(
|
||||||
|
response_schema,
|
||||||
|
{
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"title": "Name of the tool to call",
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"arguments": {
|
||||||
|
"title": "Arguments to pass to the tool",
|
||||||
|
"type": "object"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["name", "arguments"]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
),
|
||||||
|
])
|
||||||
|
)
|
||||||
|
|
||||||
|
@typechecked
|
||||||
|
def parse(self, s: str) -> Optional[Message]:
|
||||||
|
s = self.args.chat_template.strip_suffix(s)
|
||||||
|
try:
|
||||||
|
data = json.loads(s)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f'Failed to parse data as JSON: {s}')
|
||||||
|
|
||||||
|
next_step = data['next_step']
|
||||||
|
if 'result' in next_step:
|
||||||
|
return Message(role="assistant", content=json.dumps(next_step['result']))
|
||||||
|
elif 'tool_calls' in next_step:
|
||||||
|
return Message(
|
||||||
|
role="assistant",
|
||||||
|
content=data["thought"],
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(id=gen_callid(), function=FunctionCall(**tc))
|
||||||
|
for tc in next_step['tool_calls']
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unexpected data: {data}')
|
||||||
|
|
||||||
|
_SHORT_TEMPLATE='\n'.join([
|
||||||
|
'Here are the tools available:',
|
||||||
|
'<tools>',
|
||||||
|
'{tools}',
|
||||||
|
'</tools>',
|
||||||
|
])
|
||||||
|
|
||||||
|
_LONG_TEMPLATE='\n'.join([
|
||||||
|
# '''You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.''',
|
||||||
|
'You may call one or more functions to assist with the user query. Don\'t make assumptions about what values to plug into functions. Here are the available tools:',
|
||||||
|
'<tools>',
|
||||||
|
'{tools}',
|
||||||
|
'</tools>',
|
||||||
|
'',
|
||||||
|
# 'Use the following json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}',
|
||||||
|
# '',
|
||||||
|
# 'For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:',
|
||||||
|
'To call each function, give its name and arguments within <tool_call></tool_call> XML tags as follows:',
|
||||||
|
'<tool_call>',
|
||||||
|
'{"name": <function-name>, "arguments": <args-dict>}',
|
||||||
|
'</tool_call>',
|
||||||
|
# 'This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.''',
|
||||||
|
])
|
||||||
|
|
||||||
|
def get_chat_handler(args: ChatHandlerArgs, allow_parallel_calls=False) -> ChatHandler:
|
||||||
|
if not args.tools:
|
||||||
|
return NoToolsChatHandler(args)
|
||||||
|
elif args.chat_template.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
|
||||||
|
return FunctionaryToolsChatHandler(args)
|
||||||
|
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_SHORT:
|
||||||
|
return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
|
||||||
|
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_LONG:
|
||||||
|
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
|
||||||
|
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL:
|
||||||
|
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, escapes_underscores=True, allow_parallel_calls=allow_parallel_calls)
|
||||||
|
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE:
|
||||||
|
return BespokeToolsChatHandler(args)
|
||||||
|
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
|
||||||
|
return Hermes2ProToolsChatHandler(args)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported tool call style: {chat_format.tool_style}")
|
raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}")
|
||||||
|
|
||||||
|
_ts_converter = SchemaToTypeScriptConverter()
|
||||||
|
|
||||||
|
def _please_respond_with_schema(schema: dict) -> str:
|
||||||
|
# sig = json.dumps(schema, indent=2)
|
||||||
|
sig = _ts_converter.visit(schema)
|
||||||
|
return f'Please respond in JSON format with the following schema: {sig}'
|
||||||
|
|
||||||
def _tools_typescript_signatures(tools: list[Tool]) -> str:
|
def _tools_typescript_signatures(tools: list[Tool]) -> str:
|
||||||
ts_converter = SchemaToTypeScriptConverter()
|
|
||||||
return 'namespace functions {' + '\n'.join(
|
return 'namespace functions {' + '\n'.join(
|
||||||
'// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
|
'// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
|
||||||
'type ' + tool.function.name + ' = (_: ' + ts_converter.visit(tool.function.parameters) + ") => any;\n"
|
'type ' + tool.function.name + ' = (_: ' + _ts_converter.visit(tool.function.parameters) + ") => any;\n"
|
||||||
for tool in tools
|
for tool in tools
|
||||||
) + '} // namespace functions'
|
) + '} // namespace functions'
|
||||||
|
|
||||||
|
@ -185,247 +587,9 @@ def _tools_schema_signatures(tools: list[Tool], indent=None) -> str:
|
||||||
for tool in tools
|
for tool in tools
|
||||||
)
|
)
|
||||||
|
|
||||||
@typechecked
|
|
||||||
def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
|
|
||||||
return style in (
|
|
||||||
ToolsPromptStyle.TOOLS_SHORT,
|
|
||||||
ToolsPromptStyle.TOOLS_LONG,
|
|
||||||
ToolsPromptStyle.TOOLS_HERMES_2_PRO,
|
|
||||||
)
|
|
||||||
|
|
||||||
_tool_call_re = re.compile(
|
_tool_call_re = re.compile(
|
||||||
'<tool_call>(.*?)</tool_call>', re.DOTALL)
|
'<tool_call>(.*?)</tool_call>', re.DOTALL)
|
||||||
_recipient_content_re = re.compile(r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)
|
_recipient_content_re = re.compile(r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)
|
||||||
|
|
||||||
def gen_callid():
|
def gen_callid():
|
||||||
return f'call_{random.randint(0, 1000000)}'
|
return f'call_{random.randint(0, 1000000)}'
|
||||||
|
|
||||||
@typechecked
|
|
||||||
def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[list[Message]]]]:
|
|
||||||
|
|
||||||
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
|
|
||||||
|
|
||||||
response_rule = converter.visit(response_schema, "response") if response_schema else None
|
|
||||||
|
|
||||||
delimiter = '<%$[SAMPLE]$%>'
|
|
||||||
user_msg = Message(role="user", content="Hey")
|
|
||||||
empty_prompt = chat_format.render([user_msg], add_generation_prompt=True).strip()
|
|
||||||
planted_prompt = chat_format.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
|
|
||||||
assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
|
|
||||||
[prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
|
|
||||||
|
|
||||||
allow_parallel_calls = False
|
|
||||||
|
|
||||||
def strip_suffix(s: str) -> str:
|
|
||||||
if s.endswith(suffix):
|
|
||||||
return s[:-len(suffix)]
|
|
||||||
else:
|
|
||||||
sys.stderr.write(f"Expected suffix ({suffix}) not found: {s}\n")
|
|
||||||
return s
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
if _outputs_tool_call_tags(chat_format.tool_style):
|
|
||||||
|
|
||||||
escapes_underscores = chat_format.tool_style != ToolsPromptStyle.TOOLS_HERMES_2_PRO
|
|
||||||
|
|
||||||
tool_rules = [
|
|
||||||
converter.visit(
|
|
||||||
dict(
|
|
||||||
type="object",
|
|
||||||
properties=dict(
|
|
||||||
name=dict(type="string", pattern='^' + tool.function.name.replace('_', f'\\?_') + '$') if escapes_underscores \
|
|
||||||
else dict(const=tool.function.name),
|
|
||||||
arguments=tool.function.parameters,
|
|
||||||
),
|
|
||||||
required=['name', 'arguments']
|
|
||||||
),
|
|
||||||
f'{tool.function.name}-tool-call'
|
|
||||||
)
|
|
||||||
for tool in tools
|
|
||||||
]
|
|
||||||
|
|
||||||
def format_literal(s: str) -> str:
|
|
||||||
if escapes_underscores:
|
|
||||||
return ' "\\\\"? "_" '.join((converter._format_literal(part) for part in s.split('_')))
|
|
||||||
else:
|
|
||||||
return converter._format_literal(s)
|
|
||||||
|
|
||||||
tool_call_rule = converter._add_rule(
|
|
||||||
'tool_call',
|
|
||||||
format_literal("<tool_call>") + " space (" +
|
|
||||||
' | '.join(tool_rules) +
|
|
||||||
") space " + format_literal("</tool_call>"))# + ' space')
|
|
||||||
|
|
||||||
# Ideally we'd want a negative lookahead of /<tool\\?_call>/, but it's just too hard to express in GBNF for now.
|
|
||||||
# So we just over-constrain the content rule to not contain literals dangerously getting close to <tool_call>
|
|
||||||
content_rule = converter._add_rule('content', '[^<] | "<" [^t<] | "<t" [^o<]')
|
|
||||||
# content_rule = converter._add_rule('content', converter.not_literal('<tool_call>'))
|
|
||||||
converter._add_rule(
|
|
||||||
'root',
|
|
||||||
# tool_call_rule)
|
|
||||||
f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if allow_parallel_calls \
|
|
||||||
else f'{content_rule}* {tool_call_rule}?')
|
|
||||||
|
|
||||||
# # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
|
|
||||||
# # OR a tool-call message respecting the schema of any of the tools
|
|
||||||
# converter._add_rule(
|
|
||||||
# "root",
|
|
||||||
# converter._format_literal(prefix) + " (" +
|
|
||||||
# (response_rule or converter.not_literal("<tool_call>")) + " | " +
|
|
||||||
# converter._format_literal("<tool_call>") + " (" +
|
|
||||||
# ' | '.join(tool_rules) +
|
|
||||||
# ") " + converter._format_literal("</tool_call>") +
|
|
||||||
# ")") # + converter._format_literal(suffix))
|
|
||||||
|
|
||||||
@typechecked
|
|
||||||
def parse(s: str) -> Optional[Message]:
|
|
||||||
s = strip_suffix(s)
|
|
||||||
|
|
||||||
if r'<tool\_call>' in s:
|
|
||||||
# Some weird escaping of underscores is happening w/ Mixtral 8x7B Instruct
|
|
||||||
s = s.replace(r'\_', '_')
|
|
||||||
|
|
||||||
parts = _tool_call_re.split(s)
|
|
||||||
if len(parts) == 1:
|
|
||||||
return Message(role="assistant", content=s)
|
|
||||||
else:
|
|
||||||
content = []
|
|
||||||
tool_calls = []
|
|
||||||
for i, part in enumerate(parts):
|
|
||||||
if i % 2 == 0:
|
|
||||||
content.append(part)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
fc = json.loads(part)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
raise ValueError(f'Failed to parse tool call as JSON: {part}\nFull string: {s}')
|
|
||||||
tool_calls.append(
|
|
||||||
ToolCall(
|
|
||||||
id=gen_callid(),
|
|
||||||
function=FunctionCall(**fc)))
|
|
||||||
|
|
||||||
content = '\n'.join(content).strip()
|
|
||||||
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
|
|
||||||
|
|
||||||
# if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
|
|
||||||
# if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
|
|
||||||
# tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
|
|
||||||
# return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
|
|
||||||
# return None
|
|
||||||
# else:
|
|
||||||
# return Message(role="assistant", content=s)
|
|
||||||
|
|
||||||
return (converter.format_grammar(), parse)
|
|
||||||
|
|
||||||
elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
|
|
||||||
# Only allowing a single tool call at a time for now.
|
|
||||||
# Note that if there were more, they'd be separated by a '<|from|>assistant' literal
|
|
||||||
|
|
||||||
tool_rules = [
|
|
||||||
converter._add_rule(
|
|
||||||
tool.function.name + '-call',
|
|
||||||
converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
|
|
||||||
converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
|
|
||||||
converter._format_literal('\n'))
|
|
||||||
# converter.visit(
|
|
||||||
# dict(
|
|
||||||
# type="object",
|
|
||||||
# properties=dict(
|
|
||||||
# name=dict(const=tool.function.name),
|
|
||||||
# arguments=tool.function.parameters,
|
|
||||||
# ),
|
|
||||||
# required=['name', 'arguments']
|
|
||||||
# ),
|
|
||||||
# f'{tool.function.name}-tool-call'
|
|
||||||
# )
|
|
||||||
for i, tool in enumerate(tools)
|
|
||||||
]
|
|
||||||
|
|
||||||
not_from_rule = converter._add_rule('not_from', converter.not_literal("<|from|>"))
|
|
||||||
content_without_start_rule = converter._add_rule(
|
|
||||||
'content_without_start',
|
|
||||||
converter._format_literal("all\n<|content|>") + ' ' + not_from_rule + '*')
|
|
||||||
start_rule = converter._add_rule('start', converter._format_literal('<|from|>assistant\n<|recipient|>'))
|
|
||||||
content_rule = converter._add_rule('content', start_rule + ' ' + content_without_start_rule)
|
|
||||||
tool_call_without_start_rule = converter._add_rule(
|
|
||||||
'tool_call_without_start',
|
|
||||||
' | '.join(tool_rules))
|
|
||||||
# + ' ' +
|
|
||||||
# converter.not_literal("all", dotall=False) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + not_from_rule + '*')
|
|
||||||
tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}')
|
|
||||||
# converter._add_rule('root', f'({content_without_start_rule} ({content_rule})* ({tool_call_rule}+ {content_rule}*)? | {tool_call_without_start_rule} (* {tool_call_rule}{content_rule}*')
|
|
||||||
converter._add_rule(
|
|
||||||
'root',
|
|
||||||
f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
|
|
||||||
f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if allow_parallel_calls \
|
|
||||||
else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
|
|
||||||
|
|
||||||
# converter._add_rule(
|
|
||||||
# "root",
|
|
||||||
# converter._format_literal(prefix) + " (" +
|
|
||||||
# (response_rule or converter.not_literal("<|recipient|>")) + " | " +
|
|
||||||
# (' | '.join(
|
|
||||||
# converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
|
|
||||||
# converter.visit(tool.function.parameters, tool.function.name + '-args')
|
|
||||||
# for tool in tools
|
|
||||||
# )) +
|
|
||||||
# ") " +
|
|
||||||
# ")") # + converter._format_literal(suffix))
|
|
||||||
|
|
||||||
@typechecked
|
|
||||||
def parse(s: str) -> Optional[Message]:
|
|
||||||
s = strip_suffix(s)
|
|
||||||
|
|
||||||
parts = _recipient_content_re.split(s)
|
|
||||||
if len(parts) == 1:
|
|
||||||
return Message(role="assistant", content=s)
|
|
||||||
else:
|
|
||||||
text_content = []
|
|
||||||
tool_calls: list[ToolCall] = []
|
|
||||||
for i in range((len(parts) - 1) // 3):
|
|
||||||
assert parts[i * 3].strip() == '', f'Unexpected content before tool call: {parts[i * 3]}'
|
|
||||||
recipient = parts[i * 3 + 1].strip()
|
|
||||||
content = parts[i * 3 + 2]
|
|
||||||
if recipient == 'all':
|
|
||||||
text_content.append(content)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
arguments = json.loads(content)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
raise ValueError(f'Failed to parse tool call content as JSON: {content}')
|
|
||||||
tool_calls.append(
|
|
||||||
ToolCall(
|
|
||||||
id=gen_callid(),
|
|
||||||
function=FunctionCall(name=recipient, arguments=arguments)))
|
|
||||||
|
|
||||||
|
|
||||||
assert parts[-1].strip() in ('', '<|stop|>'), f'Unexpected content after tool calls: {parts[-1]}\nFull string: {s}'
|
|
||||||
|
|
||||||
content = '\n'.join(text_content).strip()
|
|
||||||
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
|
|
||||||
|
|
||||||
return (converter.format_grammar(), parse)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported tool call style: {chat_format.tool_style}")
|
|
||||||
|
|
||||||
elif response_schema:
|
|
||||||
converter._add_rule("root", response_rule + ' ' + converter._format_literal(suffix))
|
|
||||||
|
|
||||||
@typechecked
|
|
||||||
def parse(s: str) -> Optional[Message]:
|
|
||||||
s = strip_suffix(s)
|
|
||||||
return Message(role="assistant", content=s)
|
|
||||||
|
|
||||||
return (converter.format_grammar(), parse)
|
|
||||||
|
|
||||||
else:
|
|
||||||
converter._add_rule("root", converter._format_literal(prefix) + ' ' + converter._format_literal(suffix))
|
|
||||||
|
|
||||||
@typechecked
|
|
||||||
def parse(s: str) -> Optional[Message]:
|
|
||||||
s = strip_suffix(s)
|
|
||||||
return Message(role="assistant", content=s)
|
|
||||||
|
|
||||||
return (None, parse)
|
|
||||||
|
|
||||||
|
|
|
@ -5,12 +5,14 @@ import json, sys, subprocess, atexit
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
|
from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
|
||||||
from examples.openai.gguf_kvs import GGUFKeyValues, Keys
|
from examples.openai.gguf_kvs import GGUFKeyValues, Keys
|
||||||
from examples.openai.api import ChatCompletionResponse, Choice, Message, ChatCompletionRequest, Usage
|
from examples.openai.api import ChatCompletionResponse, Choice, Message, ChatCompletionRequest, Usage
|
||||||
from examples.openai.prompting import ChatFormat, make_grammar, make_tools_prompt
|
from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, get_chat_handler, ChatHandler
|
||||||
|
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
@ -38,8 +40,8 @@ def main(
|
||||||
|
|
||||||
metadata = GGUFKeyValues(model)
|
metadata = GGUFKeyValues(model)
|
||||||
context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
|
context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
|
||||||
chat_format = ChatFormat.from_gguf(metadata)
|
chat_template = ChatTemplate.from_gguf(metadata)
|
||||||
# print(chat_format)
|
# print(chat_template)
|
||||||
|
|
||||||
if not cpp_server_endpoint:
|
if not cpp_server_endpoint:
|
||||||
sys.stderr.write(f"# Starting C++ server with model {model} on {cpp_server_host}:{cpp_server_port}\n")
|
sys.stderr.write(f"# Starting C++ server with model {model} on {cpp_server_host}:{cpp_server_port}\n")
|
||||||
|
@ -69,18 +71,17 @@ def main(
|
||||||
else:
|
else:
|
||||||
response_schema = None
|
response_schema = None
|
||||||
|
|
||||||
|
chat_handler = get_chat_handler(ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools))
|
||||||
|
|
||||||
messages = chat_request.messages
|
messages = chat_request.messages
|
||||||
if chat_request.tools:
|
if chat_handler.output_format_prompt:
|
||||||
messages = chat_format.add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools))
|
messages = chat_template.add_system_prompt(messages, chat_handler.output_format_prompt)
|
||||||
|
|
||||||
(grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
|
prompt = chat_template.render(messages, add_generation_prompt=True)
|
||||||
|
|
||||||
# TODO: Test whether the template supports formatting tool_calls
|
|
||||||
|
|
||||||
prompt = chat_format.render(messages, add_generation_prompt=True)
|
|
||||||
|
|
||||||
|
sys.stderr.write(f'\n# MESSAGES:\n\n{TypeAdapter(list[Message]).dump_json(messages)}\n\n')
|
||||||
sys.stderr.write(f'\n# PROMPT:\n\n{prompt}\n\n')
|
sys.stderr.write(f'\n# PROMPT:\n\n{prompt}\n\n')
|
||||||
sys.stderr.write(f'\n# GRAMMAR:\n\n{grammar}\n\n')
|
sys.stderr.write(f'\n# GRAMMAR:\n\n{chat_handler.grammar}\n\n')
|
||||||
|
|
||||||
data = LlamaCppServerCompletionRequest(
|
data = LlamaCppServerCompletionRequest(
|
||||||
**{
|
**{
|
||||||
|
@ -94,9 +95,10 @@ def main(
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
grammar=grammar,
|
grammar=chat_handler.grammar,
|
||||||
).model_dump()
|
).model_dump()
|
||||||
sys.stderr.write(json.dumps(data, indent=2) + "\n")
|
# sys.stderr.write(json.dumps(data, indent=2) + "\n")
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{cpp_server_endpoint}/completions",
|
f"{cpp_server_endpoint}/completions",
|
||||||
|
@ -116,7 +118,7 @@ def main(
|
||||||
return JSONResponse(result)
|
return JSONResponse(result)
|
||||||
|
|
||||||
# print(json.dumps(result.get('content'), indent=2))
|
# print(json.dumps(result.get('content'), indent=2))
|
||||||
message = parser(result["content"])
|
message = chat_handler.parse(result["content"])
|
||||||
assert message is not None, f"Failed to parse response:\n{response.text}\n\n"
|
assert message is not None, f"Failed to parse response:\n{response.text}\n\n"
|
||||||
|
|
||||||
prompt_tokens=result['timings']['prompt_n']
|
prompt_tokens=result['timings']['prompt_n']
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue