server.py: make tools work w/ mixtral-8x7b-instruct

ochafik 2024-03-27 00:12:14 +00:00
parent d5d9993679
commit 8afd4de17b
4 changed files with 61 additions and 26 deletions

View file

@@ -56,6 +56,10 @@ The new examples/openai/server.py:
## TODO
- Support tool result messages
- Refactor /
- Embedding endpoint w/ distinct server subprocess
- Automatic/manual session caching

View file

@@ -3,6 +3,7 @@ import jinja2
import json
from pathlib import Path
import sys
+ import re
from typing import Optional, Tuple, Callable
from typeguard import typechecked
@@ -15,6 +16,7 @@ from examples.openai.ts_converter import SchemaToTypeScriptConverter
def raise_exception(msg: str):
    raise Exception(msg)

+ @typechecked
class ChatFormat:
    def __init__(self, template: str, eos_token: str, bos_token: str):
        env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
@@ -32,14 +34,43 @@ class ChatFormat:
    def __str__(self):
        return f"ChatFormat(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})"

+     def add_system_prompt(self, messages: list[Message], system_prompt: Message) -> list[Message]:
+         assert system_prompt.role == "system"
+         # TODO: add to last system message, or create a new one just before the last user message
+         system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
+         if system_message is not None:
+             (i, m) = system_message
+             return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
+         else:
+             return [Message(role="system", content=system_prompt.content)] + messages
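A minimal usage sketch (not from the commit, illustrative values only) of the two branches of the merge above, assuming the Message model from examples/openai and a ChatFormat instance fmt:

    tools_prompt = Message(role="system", content="You may call these tools: ...")
    # 1) An existing system message gets the extra prompt appended to it:
    msgs = [Message(role="system", content="Be concise."), Message(role="user", content="Hi")]
    merged = fmt.add_system_prompt(msgs, tools_prompt)
    # merged[0].content == "Be concise.\nYou may call these tools: ..."
    # 2) With no system message, a fresh one is prepended:
    merged = fmt.add_system_prompt([Message(role="user", content="Hi")], tools_prompt)
    # merged[0].role == "system"; merged[1].role == "user"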
    @staticmethod
    def from_gguf(metadata: GGUFKeyValues):
+         tokens = metadata[Keys.Tokenizer.LIST]
        return ChatFormat(
            template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
-             bos_token = metadata[Keys.Tokenizer.BOS_ID],
-             eos_token = metadata[Keys.Tokenizer.EOS_ID])
+             bos_token = tokens[metadata[Keys.Tokenizer.BOS_ID]],
+             eos_token = tokens[metadata[Keys.Tokenizer.EOS_ID]])
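The fix here: BOS/EOS come out of GGUF metadata as integer token IDs, but the chat template needs the token strings, so the IDs are now looked up in the vocabulary list. A sketch with illustrative values (a Llama-style vocab assumed):

    tokens = ["<unk>", "<s>", "</s>"]   # what metadata[Keys.Tokenizer.LIST] holds (illustrative)
    bos_id, eos_id = 1, 2               # metadata[Keys.Tokenizer.BOS_ID] / EOS_ID
    assert tokens[eos_id] == "</s>"     # the template now receives "</s>" rather than the int 2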
-     def render(self, messages: list[dict], add_generation_prompt: bool, omit_bos: bool = False):
+     def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
+         if self.strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
+             new_messages=[]
+             i = 0
+             n = len(messages)
+             while i < n:
+                 if messages[i].role == 'system':
+                     assert messages[i+1].role == 'user'
+                     new_messages.append(Message(
+                         role="user",
+                         content=f'[SYS]{messages[i].content}[/SYS]\n{messages[i+1].content}'))
+                     i += 2
+                 else:
+                     new_messages.append(messages[i])
+                     i += 1
+             # print(f'new_messages={json.dumps(new_messages, indent=2)}')
+             messages = new_messages
+         print(f'messages={messages}')
        return self.template.render(
            messages=messages,
            eos_token=self.eos_token,
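Context for the folding block above: mixtral-instruct's chat template raises an exception unless roles strictly alternate user/assistant, and it has no system role at all, so each system message is merged into the next user turn. The [SYS]...[/SYS] wrapper is this example's own convention, not something the model's template defines. A sketch of the transformation (illustrative values):

    before = [Message(role="system", content="You have tools."),
              Message(role="user", content="Weather in Glasgow?")]
    # becomes, with strict alternation preserved:
    after = [Message(role="user", content="[SYS]You have tools.[/SYS]\nWeather in Glasgow?")]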
@@ -144,6 +175,8 @@ def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
ToolsPromptStyle.TOOLS_HERMES_2_PRO,
)
+ _tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)
@typechecked
def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[Message]]]:
@@ -152,8 +185,9 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
    response_rule = converter.visit(response_schema, "response") if response_schema else None

    delimiter = '<%$[SAMPLE]$%>'
-     empty_prompt = chat_format.render([], add_generation_prompt=True).strip()
-     planted_prompt = chat_format.render([{"role": "assistant", "content": delimiter}], add_generation_prompt=False).strip()
+     user_msg = Message(role="user", content="Hey")
+     empty_prompt = chat_format.render([user_msg], add_generation_prompt=True).strip()
+     planted_prompt = chat_format.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
    assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
    [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
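How the sentinel trick works: the same one-message conversation is rendered twice, with and without a planted assistant turn, and the difference is split on the delimiter to recover whatever the template wraps around assistant output. The dummy user message is the change that makes this work with mixtral: strict templates refuse to render an empty or assistant-first conversation. With a ChatML-style template (purely illustrative), the recovered pieces would be:

    # planted_prompt ends in: "<|im_start|>assistant\n<%$[SAMPLE]$%><|im_end|>"
    # empty_prompt   ends in: "<|im_start|>assistant\n"
    # -> prefix == "" and suffix == "<|im_end|>", which parse uses to detect the end of a tool call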
@@ -187,14 +221,21 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
    @typechecked
    def parse(s: str) -> Optional[Message]:
        ls = s.lstrip()
        if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
            if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
                tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
                return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
            return None
        else:
            # ls = s.lstrip()
            parts = _tool_call_re.split(s)
            if len(parts) == 1:
                return Message(role="assistant", content=s)
            else:
                content = []
                tool_calls = []
                for i, part in enumerate(parts):
                    if i % 2 == 0:
                        content.append(part)
                    else:
                        tool_calls.append(json.loads(part))
                content = ''.join(content).strip()
                return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls)

    return (converter.format_grammar(), parse)
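A hedged example (not from the commit) of what the regex branch of parse returns for a completion mixing prose with an inline tool call; re.split with a capturing group alternates text (even indices) and captured JSON payloads (odd indices):

    s = 'Let me check.<tool_call>{"name": "get_weather", "arguments": {"city": "Glasgow"}}</tool_call>'
    parts = _tool_call_re.split(s)
    # parts == ['Let me check.', '{"name": "get_weather", "arguments": {"city": "Glasgow"}}', '']
    # -> Message(role="assistant", content='Let me check.',
    #            tool_calls=[{'name': 'get_weather', 'arguments': {'city': 'Glasgow'}}])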

View file

@@ -19,17 +19,6 @@ from typing import Annotated, Optional
import typer
from typeguard import typechecked
- @typechecked
- def _add_system_prompt(messages: list[Message], system_prompt: Message) -> list[Message]:
-     assert system_prompt.role == "system"
-     # TODO: add to last system message, or create a new one just before the last user message
-     system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
-     if system_message is not None:
-         (i, m) = system_message
-         return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
-     else:
-         return [Message(role="system", content=system_prompt)] + messages
def main(
model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
# model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
@@ -75,7 +64,7 @@ def main(
        messages = chat_request.messages
        if chat_request.tools:
-             messages = _add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools))
+             messages = chat_format.add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools))

        (grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
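For context (hedged; the wiring sits outside this hunk): make_grammar returns a GBNF grammar string, forwarded to the underlying llama.cpp server to constrain sampling to well-formed tool-call JSON, plus a parser closure that maps the raw completion text back to a Message:

    # response_text = completion from the llama.cpp server, sampled under `grammar`
    # message = parser(response_text)   # Message with content and/or tool_calls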

View file

@@ -12,7 +12,8 @@ function cleanup() {
trap cleanup EXIT
echo "# Starting the server"
- python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf &
+ python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
+ # python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf &
SERVER_PID=$!
sleep 5
@@ -73,7 +74,7 @@ curl http://localhost:8080/v1/chat/completions \
}],
"messages": [
{"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
{"role": "user", "content": "what is the weather going to be like in San Francisco and Glasgow over the next 4 days. Give the temperatyre in celsius for both locations."}
{"role": "user", "content": "what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
]
}'