server.py: make tools work w/ mixtral-8x7b-instruct

ochafik 2024-03-27 00:12:14 +00:00
parent d5d9993679
commit 8afd4de17b
4 changed files with 61 additions and 26 deletions

View file

@@ -56,6 +56,10 @@ The new examples/openai/server.py:
 ## TODO
+- Support tool result messages
+- Reactor /
 - Embedding endpoint w/ distinct server subprocess
 - Automatic/manual session caching

View file

@@ -3,6 +3,7 @@ import jinja2
 import json
 from pathlib import Path
 import sys
+import re
 from typing import Optional, Tuple, Callable
 from typeguard import typechecked
@@ -15,6 +16,7 @@ from examples.openai.ts_converter import SchemaToTypeScriptConverter
 def raise_exception(msg: str):
     raise Exception(msg)
 
+@typechecked
 class ChatFormat:
     def __init__(self, template: str, eos_token: str, bos_token: str):
         env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
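A side note on the new decorator: typeguard's @typechecked can be applied to a whole class, which instruments every method with runtime checks against its type annotations. A minimal sketch of what that buys (the Greeter class is purely illustrative):

from typeguard import typechecked

@typechecked
class Greeter:
    def greet(self, name: str) -> str:
        return f"Hello, {name}"

Greeter().greet("world")  # ok
Greeter().greet(42)       # raises a runtime type-check error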
@@ -32,14 +34,43 @@ class ChatFormat:
     def __str__(self):
         return f"ChatFormat(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})"
 
+    def add_system_prompt(self, messages: list[Message], system_prompt: Message) -> list[Message]:
+        assert system_prompt.role == "system"
+        # TODO: add to last system message, or create a new one just before the last user message
+        system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
+        if system_message is not None:
+            (i, m) = system_message
+            return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
+        else:
+            return [Message(role="system", content=system_prompt.content)] + messages
+
     @staticmethod
     def from_gguf(metadata: GGUFKeyValues):
+        tokens = metadata[Keys.Tokenizer.LIST]
         return ChatFormat(
             template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
-            bos_token = metadata[Keys.Tokenizer.BOS_ID],
-            eos_token = metadata[Keys.Tokenizer.EOS_ID])
+            bos_token = tokens[metadata[Keys.Tokenizer.BOS_ID]],
+            eos_token = tokens[metadata[Keys.Tokenizer.EOS_ID]])
 
-    def render(self, messages: list[dict], add_generation_prompt: bool, omit_bos: bool = False):
+    def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
+        if self.strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
+            new_messages = []
+            i = 0
+            n = len(messages)
+            while i < n:
+                if messages[i].role == 'system':
+                    assert messages[i+1].role == 'user'
+                    new_messages.append(Message(
+                        role="user",
+                        content=f'[SYS]{messages[i].content}[/SYS]\n{messages[i+1].content}'))
+                    i += 2
+                else:
+                    new_messages.append(messages[i])
+                    i += 1
+            # print(f'new_messages={json.dumps(new_messages, indent=2)}')
+            messages = new_messages
+            print(f'messages={messages}')
         return self.template.render(
             messages=messages,
             eos_token=self.eos_token,
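Two changes are worth unpacking here. In from_gguf, the BOS/EOS metadata entries are integer token ids, but Jinja chat templates interpolate bos_token/eos_token as text, so the ids are now used to index the tokenizer's token list and recover the literal strings (e.g. <s> and </s> for Mixtral). In render, Mixtral-instruct's template raises an exception on any role other than user/assistant (and requires strict alternation), so system messages are folded into the user turn that follows them; the strict_user_assistant_alternation flag is presumably set in __init__ from the template source, which this excerpt does not show. A self-contained sketch of that folding on plain dicts, using the same [SYS]...[/SYS] wrapper as the diff (fold_system_messages is a hypothetical helper name):

def fold_system_messages(messages: list[dict]) -> list[dict]:
    """Merge each system message into the user message that follows it,
    for templates that only accept alternating user/assistant turns."""
    folded = []
    i = 0
    while i < len(messages):
        if messages[i]["role"] == "system" and i + 1 < len(messages) and messages[i + 1]["role"] == "user":
            folded.append({
                "role": "user",
                "content": f"[SYS]{messages[i]['content']}[/SYS]\n{messages[i + 1]['content']}",
            })
            i += 2  # consumed the system message and the following user turn
        else:
            folded.append(messages[i])
            i += 1
    return folded

# e.g. [system, user, assistant] becomes [user("[SYS]...[/SYS]\n..."), assistant]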
@@ -144,6 +175,8 @@ def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
         ToolsPromptStyle.TOOLS_HERMES_2_PRO,
     )
 
+_tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)
+
 @typechecked
 def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[Message]]]:
@@ -152,8 +185,9 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
     response_rule = converter.visit(response_schema, "response") if response_schema else None
 
     delimiter = '<%$[SAMPLE]$%>'
-    empty_prompt = chat_format.render([], add_generation_prompt=True).strip()
-    planted_prompt = chat_format.render([{"role": "assistant", "content": delimiter}], add_generation_prompt=False).strip()
+    user_msg = Message(role="user", content="Hey")
+    empty_prompt = chat_format.render([user_msg], add_generation_prompt=True).strip()
+    planted_prompt = chat_format.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
     assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
     [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
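For context on what this hunk touches: make_grammar discovers the assistant-message prefix/suffix by rendering the chat twice, with and without a planted assistant message holding a sentinel delimiter, then splitting the difference on that sentinel. Mixtral's template refuses a conversation that does not start with a user turn, hence the dummy user_msg now seeding both renders. A toy illustration with a hand-rolled stand-in for chat_format.render (not the real template):

def render(messages):  # stand-in for chat_format.render(..., add_generation_prompt=...)
    out = ""
    for m in messages:
        if m["role"] == "user":
            out += f"[INST] {m['content']} [/INST]"
        else:
            out += f" {m['content']}</s>"
    return out

delimiter = '<%$[SAMPLE]$%>'
user_msg = {"role": "user", "content": "Hey"}
empty_prompt = render([user_msg]).strip()
planted_prompt = render([user_msg, {"role": "assistant", "content": delimiter}]).strip()
assert planted_prompt.startswith(empty_prompt)
prefix, suffix = planted_prompt[len(empty_prompt):].split(delimiter)
# prefix == ' ', suffix == '</s>': the text the template wraps around assistant output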
@@ -187,14 +221,21 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
     @typechecked
     def parse(s: str) -> Optional[Message]:
-        ls = s.lstrip()
-        if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
-            if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
-                tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
-                return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
-            return None
-        else:
+        # ls = s.lstrip()
+        parts = _tool_call_re.split(s)
+        if len(parts) == 1:
             return Message(role="assistant", content=s)
+        else:
+            content = []
+            tool_calls = []
+            for i, part in enumerate(parts):
+                if i % 2 == 0:
+                    content.append(part)
+                else:
+                    tool_calls.append(json.loads(part))
+            content = ''.join(content).strip()
+            return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls)
 
     return (converter.format_grammar(), parse)
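The rewritten parser relies on a documented re.split behavior: when the pattern contains a capturing group, the captures are interleaved into the result list, so even indices hold surrounding text and odd indices hold the captured tool-call JSON. That copes with free text around the call and with multiple <tool_call> blocks, where the old code only accepted a single call spanning the whole output. A quick check of the behavior (the sample model output is illustrative):

import re, json

_tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)

s = 'Let me look that up.\n<tool_call>{"name": "get_weather", "arguments": {"city": "Glasgow"}}</tool_call>'
parts = _tool_call_re.split(s)
# ['Let me look that up.\n', '{"name": "get_weather", "arguments": {"city": "Glasgow"}}', '']
content = ''.join(parts[0::2]).strip()
tool_calls = [json.loads(p) for p in parts[1::2]]
assert content == 'Let me look that up.'
assert tool_calls[0]['name'] == 'get_weather'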

View file

@@ -19,17 +19,6 @@ from typing import Annotated, Optional
 import typer
 from typeguard import typechecked
 
-@typechecked
-def _add_system_prompt(messages: list[Message], system_prompt: Message) -> list[Message]:
-    assert system_prompt.role == "system"
-    # TODO: add to last system message, or create a new one just before the last user message
-    system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
-    if system_message is not None:
-        (i, m) = system_message
-        return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
-    else:
-        return [Message(role="system", content=system_prompt)] + messages
-
 def main(
     model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
     # model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
@@ -75,7 +64,7 @@ def main(
         messages = chat_request.messages
         if chat_request.tools:
-            messages = _add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools))
+            messages = chat_format.add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools))
 
         (grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
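With the helper promoted to a ChatFormat method, the handler's request path reads roughly as follows. This is a simplified sketch using names from this diff; run_inference is a hypothetical stand-in for the actual call into the llama.cpp server:

messages = chat_request.messages
if chat_request.tools:
    # Fold a tool-definitions system prompt into any existing system message.
    messages = chat_format.add_system_prompt(
        messages, make_tools_prompt(chat_format, chat_request.tools))
(grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
prompt = chat_format.render(messages, add_generation_prompt=True)
output = run_inference(prompt, grammar=grammar)  # hypothetical helper
message = parser(output)  # Message with content and/or tool_calls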

View file

@@ -12,7 +12,8 @@ function cleanup() {
 trap cleanup EXIT
 
 echo "# Starting the server"
-python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf &
+python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
+# python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf &
 SERVER_PID=$!
 
 sleep 5
@@ -73,7 +74,7 @@ curl http://localhost:8080/v1/chat/completions \
     }],
     "messages": [
       {"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
-      {"role": "user", "content": "what is the weather going to be like in San Francisco and Glasgow over the next 4 days. Give the temperatyre in celsius for both locations."}
+      {"role": "user", "content": "what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
     ]
   }'