server.py: make tools work w/ mixtral-8x7b-instruct
This commit is contained in:
parent
d5d9993679
commit
8afd4de17b
4 changed files with 61 additions and 26 deletions
|
@ -56,6 +56,10 @@ The new examples/openai/server.py:
|
|||
|
||||
## TODO
|
||||
|
||||
- Support tool result messages
|
||||
|
||||
- Reactor /
|
||||
|
||||
- Embedding endpoint w/ distinct server subprocess
|
||||
|
||||
- Automatic/manual session caching
|
||||
|
|
|
@ -3,6 +3,7 @@ import jinja2
|
|||
import json
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import re
|
||||
from typing import Optional, Tuple, Callable
|
||||
from typeguard import typechecked
|
||||
|
||||
|
@ -15,6 +16,7 @@ from examples.openai.ts_converter import SchemaToTypeScriptConverter
|
|||
def raise_exception(msg: str):
|
||||
raise Exception(msg)
|
||||
|
||||
@typechecked
|
||||
class ChatFormat:
|
||||
def __init__(self, template: str, eos_token: str, bos_token: str):
|
||||
env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
|
||||
|
@ -32,14 +34,43 @@ class ChatFormat:
|
|||
def __str__(self):
|
||||
return f"ChatFormat(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})"
|
||||
|
||||
def add_system_prompt(self, messages: list[Message], system_prompt: Message) -> list[Message]:
|
||||
assert system_prompt.role == "system"
|
||||
# TODO: add to last system message, or create a new one just before the last user message
|
||||
system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
|
||||
if system_message is not None:
|
||||
(i, m) = system_message
|
||||
return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
|
||||
else:
|
||||
return [Message(role="system", content=system_prompt)] + messages
|
||||
|
||||
@staticmethod
|
||||
def from_gguf(metadata: GGUFKeyValues):
|
||||
tokens = metadata[Keys.Tokenizer.LIST]
|
||||
return ChatFormat(
|
||||
template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
|
||||
bos_token = metadata[Keys.Tokenizer.BOS_ID],
|
||||
eos_token = metadata[Keys.Tokenizer.EOS_ID])
|
||||
bos_token = tokens[metadata[Keys.Tokenizer.BOS_ID]],
|
||||
eos_token = tokens[metadata[Keys.Tokenizer.EOS_ID]])
|
||||
|
||||
def render(self, messages: list[dict], add_generation_prompt: bool, omit_bos: bool = False):
|
||||
def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
|
||||
if self.strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
|
||||
new_messages=[]
|
||||
i = 0
|
||||
n = len(messages)
|
||||
while i < n:
|
||||
if messages[i].role == 'system':
|
||||
assert messages[i+1].role == 'user'
|
||||
new_messages.append(Message(
|
||||
role="user",
|
||||
content=f'[SYS]{messages[i].content}[/SYS]\n{messages[i+1].content}'))
|
||||
i += 2
|
||||
else:
|
||||
new_messages.append(messages[i])
|
||||
i += 1
|
||||
# print(f'new_messages={json.dumps(new_messages, indent=2)}')
|
||||
messages = new_messages
|
||||
print(f'messages={messages}')
|
||||
|
||||
return self.template.render(
|
||||
messages=messages,
|
||||
eos_token=self.eos_token,
|
||||
|
@ -144,6 +175,8 @@ def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
|
|||
ToolsPromptStyle.TOOLS_HERMES_2_PRO,
|
||||
)
|
||||
|
||||
_tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)
|
||||
|
||||
@typechecked
|
||||
def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[Message]]]:
|
||||
|
||||
|
@ -152,8 +185,9 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
|
|||
response_rule = converter.visit(response_schema, "response") if response_schema else None
|
||||
|
||||
delimiter = '<%$[SAMPLE]$%>'
|
||||
empty_prompt = chat_format.render([], add_generation_prompt=True).strip()
|
||||
planted_prompt = chat_format.render([{"role": "assistant", "content": delimiter}], add_generation_prompt=False).strip()
|
||||
user_msg = Message(role="user", content="Hey")
|
||||
empty_prompt = chat_format.render([user_msg], add_generation_prompt=True).strip()
|
||||
planted_prompt = chat_format.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
|
||||
assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
|
||||
[prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
|
||||
|
||||
|
@ -187,14 +221,21 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
|
|||
|
||||
@typechecked
|
||||
def parse(s: str) -> Optional[Message]:
|
||||
ls = s.lstrip()
|
||||
if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
|
||||
if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
|
||||
tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
|
||||
return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
|
||||
return None
|
||||
else:
|
||||
# ls = s.lstrip()
|
||||
parts = _tool_call_re.split(s)
|
||||
if len(parts) == 1:
|
||||
return Message(role="assistant", content=s)
|
||||
else:
|
||||
content = []
|
||||
tool_calls = []
|
||||
for i, part in enumerate(parts):
|
||||
if i % 2 == 0:
|
||||
content.append(part)
|
||||
else:
|
||||
tool_calls.append(json.loads(part))
|
||||
|
||||
content = ''.join(content).strip()
|
||||
return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls)
|
||||
|
||||
return (converter.format_grammar(), parse)
|
||||
|
||||
|
|
|
@ -19,17 +19,6 @@ from typing import Annotated, Optional
|
|||
import typer
|
||||
from typeguard import typechecked
|
||||
|
||||
@typechecked
|
||||
def _add_system_prompt(messages: list[Message], system_prompt: Message) -> list[Message]:
|
||||
assert system_prompt.role == "system"
|
||||
# TODO: add to last system message, or create a new one just before the last user message
|
||||
system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
|
||||
if system_message is not None:
|
||||
(i, m) = system_message
|
||||
return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
|
||||
else:
|
||||
return [Message(role="system", content=system_prompt)] + messages
|
||||
|
||||
def main(
|
||||
model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
|
||||
# model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
|
||||
|
@ -75,7 +64,7 @@ def main(
|
|||
|
||||
messages = chat_request.messages
|
||||
if chat_request.tools:
|
||||
messages = _add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools))
|
||||
messages = chat_format.add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools))
|
||||
|
||||
(grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
|
||||
|
||||
|
|
|
@ -12,7 +12,8 @@ function cleanup() {
|
|||
trap cleanup EXIT
|
||||
|
||||
echo "# Starting the server"
|
||||
python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf &
|
||||
python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
|
||||
# python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf &
|
||||
SERVER_PID=$!
|
||||
|
||||
sleep 5
|
||||
|
@ -73,7 +74,7 @@ curl http://localhost:8080/v1/chat/completions \
|
|||
}],
|
||||
"messages": [
|
||||
{"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
|
||||
{"role": "user", "content": "what is the weather going to be like in San Francisco and Glasgow over the next 4 days. Give the temperatyre in celsius for both locations."}
|
||||
{"role": "user", "content": "what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
|
||||
]
|
||||
}'
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue