server.py: make tools work w/ mixtral-8x7b-instruct

parent d5d9993679
commit 8afd4de17b

4 changed files with 61 additions and 26 deletions

@@ -56,6 +56,10 @@ The new examples/openai/server.py:

 ## TODO

+- Support tool result messages
+
+- Refactor /
+
 - Embedding endpoint w/ distinct server subprocess

 - Automatic/manual session caching

@@ -3,6 +3,7 @@ import jinja2
 import json
 from pathlib import Path
 import sys
+import re
 from typing import Optional, Tuple, Callable
 from typeguard import typechecked

@@ -15,6 +16,7 @@ from examples.openai.ts_converter import SchemaToTypeScriptConverter
 def raise_exception(msg: str):
     raise Exception(msg)

+@typechecked
 class ChatFormat:
     def __init__(self, template: str, eos_token: str, bos_token: str):
         env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)

@@ -32,14 +34,43 @@ class ChatFormat:
     def __str__(self):
         return f"ChatFormat(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})"

+    def add_system_prompt(self, messages: list[Message], system_prompt: Message) -> list[Message]:
+        assert system_prompt.role == "system"
+        # TODO: add to last system message, or create a new one just before the last user message
+        system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
+        if system_message is not None:
+            (i, m) = system_message
+            return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
+        else:
+            return [Message(role="system", content=system_prompt.content)] + messages
+
     @staticmethod
     def from_gguf(metadata: GGUFKeyValues):
+        tokens = metadata[Keys.Tokenizer.LIST]
         return ChatFormat(
             template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
-            bos_token = metadata[Keys.Tokenizer.BOS_ID],
-            eos_token = metadata[Keys.Tokenizer.EOS_ID])
+            bos_token = tokens[metadata[Keys.Tokenizer.BOS_ID]],
+            eos_token = tokens[metadata[Keys.Tokenizer.EOS_ID]])

-    def render(self, messages: list[dict], add_generation_prompt: bool, omit_bos: bool = False):
+    def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
+        if self.strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
+            new_messages = []
+            i = 0
+            n = len(messages)
+            while i < n:
+                if messages[i].role == 'system':
+                    assert messages[i+1].role == 'user'
+                    new_messages.append(Message(
+                        role="user",
+                        content=f'[SYS]{messages[i].content}[/SYS]\n{messages[i+1].content}'))
+                    i += 2
+                else:
+                    new_messages.append(messages[i])
+                    i += 1
+            # print(f'new_messages={json.dumps(new_messages, indent=2)}')
+            messages = new_messages
+            print(f'messages={messages}')
+
         return self.template.render(
             messages=messages,
             eos_token=self.eos_token,
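
The render() rewrite above is the heart of the Mixtral fix: mixtral-8x7b-instruct's chat template enforces strictly alternating user/assistant turns, so when strict_user_assistant_alternation is set, each system message is folded into the user message that follows it. A rough standalone sketch of that folding step, with plain dicts standing in for the example's Message type (an assumption made here for brevity):

    # Standalone sketch of the alternation fix, assuming plain dicts:
    def fold_system_messages(messages: list[dict]) -> list[dict]:
        folded = []
        i = 0
        while i < len(messages):
            if messages[i]['role'] == 'system':
                # The template rejects roles other than user/assistant,
                # so splice the system text into the next user turn.
                assert messages[i + 1]['role'] == 'user'
                folded.append({
                    'role': 'user',
                    'content': f"[SYS]{messages[i]['content']}[/SYS]\n{messages[i + 1]['content']}",
                })
                i += 2
            else:
                folded.append(messages[i])
                i += 1
        return folded

    assert fold_system_messages(
        [{'role': 'system', 'content': 'Be terse.'},
         {'role': 'user', 'content': 'Hi'}]
    ) == [{'role': 'user', 'content': '[SYS]Be terse.[/SYS]\nHi'}]

The [SYS]…[/SYS] markers appear to be this commit's own convention for keeping the system text visible to the model, not tokens defined by the Mixtral template itself.
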
@@ -144,6 +175,8 @@ def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
         ToolsPromptStyle.TOOLS_HERMES_2_PRO,
     )

+_tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)
+
 @typechecked
 def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[Message]]]:

@@ -152,8 +185,9 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
     response_rule = converter.visit(response_schema, "response") if response_schema else None

     delimiter = '<%$[SAMPLE]$%>'
-    empty_prompt = chat_format.render([], add_generation_prompt=True).strip()
-    planted_prompt = chat_format.render([{"role": "assistant", "content": delimiter}], add_generation_prompt=False).strip()
+    user_msg = Message(role="user", content="Hey")
+    empty_prompt = chat_format.render([user_msg], add_generation_prompt=True).strip()
+    planted_prompt = chat_format.render([user_msg, Message(role="assistant", content=delimiter)], add_generation_prompt=False).strip()
     assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
     [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)

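This hunk relies on a delimiter trick worth spelling out: render the conversation once bare and once with a planted assistant message holding a sentinel string, subtract the common prefix, and split on the sentinel to recover whatever text the template wraps around assistant output. The switch from an empty message list to a real user message matters because the Mistral-family templates reject conversations that don't open with a user turn, which is presumably why the empty list had to go. A minimal sketch against a toy jinja2 template (illustrative only, not the real Mixtral template):

    import jinja2

    template = jinja2.Environment().from_string(
        '{% for m in messages %}[{{ m.role }}]{{ m.content }}[/{{ m.role }}]{% endfor %}')

    delimiter = '<%$[SAMPLE]$%>'
    user = {'role': 'user', 'content': 'Hey'}
    empty_prompt = template.render(messages=[user]).strip()
    planted_prompt = template.render(
        messages=[user, {'role': 'assistant', 'content': delimiter}]).strip()

    assert planted_prompt.startswith(empty_prompt)
    prefix, suffix = planted_prompt[len(empty_prompt):].split(delimiter)
    print(repr(prefix), repr(suffix))  # '[assistant]' '[/assistant]'
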
@@ -187,14 +221,21 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op

     @typechecked
     def parse(s: str) -> Optional[Message]:
-        ls = s.lstrip()
-        if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
-            if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
-                tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
-                return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
-            return None
-        else:
+        # ls = s.lstrip()
+        parts = _tool_call_re.split(s)
+        if len(parts) == 1:
             return Message(role="assistant", content=s)
+        else:
+            content = []
+            tool_calls = []
+            for i, part in enumerate(parts):
+                if i % 2 == 0:
+                    content.append(part)
+                else:
+                    tool_calls.append(json.loads(part))
+            content = ''.join(content).strip()
+            return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls)

     return (converter.format_grammar(), parse)

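The rewritten parse() leans on a documented property of re.split: when the pattern contains a capturing group, the captured text is kept in the result list, so parts alternates between plain content (even indices) and tool-call JSON payloads (odd indices), and len(parts) == 1 means no tool call was emitted at all. A self-contained demonstration (the payload below is made up for illustration):

    import json
    import re

    _tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)

    s = 'Let me check.<tool_call>{"name": "get_weather", "arguments": {"city": "Glasgow"}}</tool_call>'
    parts = _tool_call_re.split(s)
    # parts == ['Let me check.', '{"name": "get_weather", "arguments": {"city": "Glasgow"}}', '']
    content = ''.join(parts[0::2]).strip()
    tool_calls = [json.loads(p) for p in parts[1::2]]
    print(content)     # Let me check.
    print(tool_calls)  # [{'name': 'get_weather', 'arguments': {'city': 'Glasgow'}}]
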
@@ -19,17 +19,6 @@ from typing import Annotated, Optional
 import typer
 from typeguard import typechecked

-@typechecked
-def _add_system_prompt(messages: list[Message], system_prompt: Message) -> list[Message]:
-    assert system_prompt.role == "system"
-    # TODO: add to last system message, or create a new one just before the last user message
-    system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
-    if system_message is not None:
-        (i, m) = system_message
-        return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
-    else:
-        return [Message(role="system", content=system_prompt)] + messages
-
 def main(
     model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
     # model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),

@@ -75,7 +64,7 @@ def main(

         messages = chat_request.messages
         if chat_request.tools:
-            messages = _add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools))
+            messages = chat_format.add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools))

         (grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)

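_add_system_prompt was not dropped; it moved onto ChatFormat as add_system_prompt (see the first Python hunk of this commit), keeping the same behavior: append the extra prompt to the first existing system message, otherwise prepend it. A dict-based sketch of that merge, with plain dicts standing in for Message (an assumption):

    def add_system_prompt(messages: list[dict], system_prompt: dict) -> list[dict]:
        assert system_prompt['role'] == 'system'
        for i, m in enumerate(messages):
            if m['role'] == 'system':
                # Merge into the first system message already present.
                merged = dict(m, content=m['content'] + '\n' + system_prompt['content'])
                return messages[:i] + [merged] + messages[i + 1:]
        # No system message yet: prepend one.
        return [system_prompt] + messages

    msgs = [{'role': 'user', 'content': 'Hi'}]
    assert add_system_prompt(msgs, {'role': 'system', 'content': 'Use tools.'})[0]['role'] == 'system'
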
@@ -12,7 +12,8 @@ function cleanup() {
 trap cleanup EXIT

 echo "# Starting the server"
-python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf &
+python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
+# python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf &
 SERVER_PID=$!

 sleep 5

@@ -73,7 +74,7 @@ curl http://localhost:8080/v1/chat/completions \
     }],
     "messages": [
       {"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
-      {"role": "user", "content": "what is the weather going to be like in San Francisco and Glasgow over the next 4 days. Give the temperatyre in celsius for both locations."}
+      {"role": "user", "content": "what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
     ]
   }'