server.py: kinda api-compliant output, disabled grammar

ochafik 2024-03-27 01:50:26 +00:00
parent 8afd4de17b
commit aa9605c751
5 changed files with 136 additions and 33 deletions

View file: examples/openai/api.py

@@ -1,10 +1,15 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, Literal, Optional, Union
from pydantic import BaseModel, Json
class ToolCall(BaseModel):
class FunctionCall(BaseModel):
name: str
arguments: Dict[str, Any]
class ToolCall(BaseModel):
id: Optional[str] = None
type: Literal["function"] = "function"
function: FunctionCall
class Message(BaseModel):
role: str
content: Optional[str]
@@ -30,3 +35,23 @@ class ChatCompletionRequest(BaseModel):
response_format: Optional[ResponseFormat] = None
temperature: float = 1.0
stream: bool = False
class Choice(BaseModel):
index: int
message: Message
logprobs: Optional[Json] = None
finish_reason: Union[Literal["stop"], Literal["tool_calls"]]
class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ChatCompletionResponse(BaseModel):
id: str
object: Literal["chat.completion"]
created: int
model: str
choices: list[Choice]
usage: Usage
system_fingerprint: str
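
These models mirror the OpenAI chat completion response schema. A minimal standalone sketch of how a tool-calling assistant message serializes (classes copied from the diff above; field values are illustrative):

from typing import Any, Dict, Literal, Optional
from pydantic import BaseModel

class FunctionCall(BaseModel):
    name: str
    arguments: Dict[str, Any]

class ToolCall(BaseModel):
    id: Optional[str] = None
    type: Literal["function"] = "function"
    function: FunctionCall

class Message(BaseModel):
    role: str
    content: Optional[str]
    tool_calls: Optional[list[ToolCall]] = None  # assumed from its use in prompting.py below

msg = Message(
    role="assistant",
    content=None,
    tool_calls=[ToolCall(
        id="call_531873",
        function=FunctionCall(name="get_current_weather",
                              arguments={"location": "Glasgow"}))])
print(msg.model_dump_json(indent=2))  # pydantic v2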

View file: examples/openai/prompting.py

@@ -2,13 +2,14 @@ from enum import Enum
import jinja2
import json
from pathlib import Path
import sys
import random
import re
import sys
from typing import Optional, Tuple, Callable
from typeguard import typechecked
from examples.json_schema_to_grammar import SchemaConverter
from examples.openai.api import Tool, Message
from examples.openai.api import Tool, Message, FunctionCall, ToolCall
from examples.openai.gguf_kvs import GGUFKeyValues, Keys
from examples.openai.ts_converter import SchemaToTypeScriptConverter
@@ -42,7 +43,7 @@ class ChatFormat:
(i, m) = system_message
return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
else:
return [Message(role="system", content=system_prompt)] + messages
return [system_prompt] + messages
@staticmethod
def from_gguf(metadata: GGUFKeyValues):
@@ -69,7 +70,7 @@ class ChatFormat:
i += 1
# print(f'new_messages={json.dumps(new_messages, indent=2)}')
messages = new_messages
print(f'messages={messages}')
# print(f'messages={messages}')
return self.template.render(
messages=messages,
@@ -175,10 +176,15 @@ def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
ToolsPromptStyle.TOOLS_HERMES_2_PRO,
)
_tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)
_tool_call_re = re.compile(
'<tool_call>(.*?)</tool_call>', re.DOTALL)
_recipient_content_re = re.compile(r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)
def gen_callid():
return f'call_{random.randint(0, 1000000)}'
@typechecked
def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[Message]]]:
def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[list[Message]]]]:
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
@@ -191,6 +197,13 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
[prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
def strip_suffix(s: str) -> str:
if s.endswith(suffix):
return s[:-len(suffix)]
else:
print(f"Expected suffix ({suffix}) not found: {s}")
return s
if tools:
if _outputs_tool_call_tags(chat_format.tool_style):
tool_rules = [
@@ -221,6 +234,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
@typechecked
def parse(s: str) -> Optional[Message]:
s = strip_suffix(s)
# ls = s.lstrip()
parts = _tool_call_re.split(s)
if len(parts) == 1:
@@ -232,10 +247,21 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
if i % 2 == 0:
content.append(part)
else:
tool_calls.append(json.loads(part))
tool_calls.append(
ToolCall(
id=gen_callid(),
function=FunctionCall(**json.loads(part))))
content = ''.join(content).strip()
return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls)
# if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
# if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
# tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
# return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
# return None
# else:
# return Message(role="assistant", content=s)
return (converter.format_grammar(), parse)
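
For reference, the <tool_call> parsing path above can be exercised standalone (regex copied from this file; the sample completion text is made up):

import json, re

_tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)

s = 'Let me look that up.\n<tool_call>{"name": "get_current_weather", "arguments": {"location": "Glasgow"}}</tool_call>'
parts = _tool_call_re.split(s)
# After split, even indices hold plain content and odd indices hold JSON tool-call payloads.
content = ''.join(parts[0::2]).strip()
tool_calls = [json.loads(p) for p in parts[1::2]]
print(content)                # Let me look that up.
print(tool_calls[0]['name'])  # get_current_weather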
@@ -256,7 +282,30 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
@typechecked
def parse(s: str) -> Optional[Message]:
raise NotImplementedError(f'TODO: parse tool_style {chat_format.tool_style}: {s}')
s = strip_suffix(s)
parts = _recipient_content_re.split(s)
if len(parts) == 1:
return Message(role="assistant", content=s)
else:
text_content = []
tool_calls: list[ToolCall] = []
for i in range((len(parts) - 1) // 3):
assert parts[i * 3].strip() == '', f'Unexpected content before tool call: {parts[i * 3]}'
recipient = parts[i * 3 + 1].strip()
content = parts[i * 3 + 2]
if recipient == 'all':
text_content.append(content)
else:
tool_calls.append(
ToolCall(
id=gen_callid(),
function=FunctionCall(name=recipient, arguments=json.loads(content))))
assert parts[-1].strip() == '', f'Unexpected content after tool calls: {parts[-1]}'
content = '\n'.join(text_content).strip()
return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls if tool_calls else None)
return (converter.format_grammar(), parse)
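
A quick standalone check of _recipient_content_re against functionary-style output (regex copied from above; the sample string is made up):

import json, re

_recipient_content_re = re.compile(r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)

s = 'get_current_weather\n<|content|>{"location": "Glasgow"}<|stop|>'
parts = _recipient_content_re.split(s)
# parts == ['', 'get_current_weather', '{"location": "Glasgow"}', '']
recipient, content = parts[1].strip(), parts[2]
print(recipient, json.loads(content))  # a recipient of 'all' would instead be plain text content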
@@ -265,8 +314,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
@typechecked
def parse(s: str) -> Optional[Message]:
if response_rule.endswith(suffix):
return Message(role="assistant", content=s[:-len(suffix)])
s = strip_suffix(s)
return Message(role="assistant", content=s)
return (converter.format_grammar(), parse)
@@ -275,9 +324,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
@typechecked
def parse(s: str) -> Optional[Message]:
if s.endswith(suffix):
return Message(role="assistant", content=s[:-len(suffix)])
return None
s = strip_suffix(s)
return Message(role="assistant", content=s)
return (None, parse)

View file: examples/openai/server.py

@@ -3,22 +3,27 @@
import json, sys, subprocess, atexit
from pathlib import Path
import time
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
from examples.openai.gguf_kvs import GGUFKeyValues, Keys
from examples.openai.api import Message, ChatCompletionRequest
from examples.openai.api import ChatCompletionResponse, Choice, Message, ChatCompletionRequest, Usage
from examples.openai.prompting import ChatFormat, make_grammar, make_tools_prompt
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import httpx
import random
from starlette.responses import StreamingResponse
from typing import Annotated, Optional
import typer
from typeguard import typechecked
def generate_id(prefix):
return f"{prefix}{random.randint(0, 1 << 32)}"
def main(
model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
# model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
@@ -68,17 +73,13 @@ def main(
(grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
if chat_format.strict_user_assistant_alternation:
print("TODO: merge system messages into user messages")
# new_messages = []
# TODO: Test whether the template supports formatting tool_calls
prompt = chat_format.render(messages, add_generation_prompt=True)
print(json.dumps(dict(
stream=chat_request.stream,
prompt=prompt,
grammar=grammar,
# grammar=grammar,
), indent=2))
async with httpx.AsyncClient() as client:
response = await client.post(
@@ -87,7 +88,7 @@ def main(
prompt=prompt,
stream=chat_request.stream,
n_predict=300,
grammar=grammar,
# grammar=grammar,
).model_dump(),
headers=headers,
timeout=None)
@@ -98,11 +99,35 @@ def main(
return StreamingResponse(generate_chunks(response), media_type="text/event-stream")
else:
result = response.json()
if 'content' not in result:
# print(json.dumps(result, indent=2))
return JSONResponse(result)
print(json.dumps(result, indent=2))
# print(json.dumps(result.get('content'), indent=2))
message = parser(result["content"])
assert message is not None, f"Failed to parse response: {response.text}"
return JSONResponse(message.model_dump())
# return JSONResponse(response.json())
assert message is not None, f"Failed to parse response:\n{response.text}\n\n"
prompt_tokens=result['timings']['prompt_n']
completion_tokens=result['timings']['predicted_n']
return JSONResponse(ChatCompletionResponse(
id=generate_id('chatcmpl-'),
object="chat.completion",
created=int(time.time()),
model=chat_request.model,
choices=[Choice(
index=0,
message=message,
finish_reason="stop" if message.tool_calls is None else "tool_calls",
)],
usage=Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
system_fingerprint='...'
).model_dump())
async def generate_chunks(response):
async for chunk in response.aiter_bytes():
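
With these changes the endpoint returns an OpenAI-style chat.completion body, roughly like this (all field values illustrative):

{
  "id": "chatcmpl-3094103227",
  "object": "chat.completion",
  "created": 1711504226,
  "model": "...",
  "choices": [{
    "index": 0,
    "message": {
      "role": "assistant",
      "content": null,
      "tool_calls": [{
        "id": "call_531873",
        "type": "function",
        "function": {"name": "get_current_weather", "arguments": {"location": "Glasgow"}}
      }]
    },
    "logprobs": null,
    "finish_reason": "tool_calls"
  }],
  "usage": {"prompt_tokens": 350, "completion_tokens": 40, "total_tokens": 390},
  "system_fingerprint": "..."
}

Note that arguments serializes as a JSON object here (FunctionCall.arguments is a Dict[str, Any]), whereas OpenAI's API encodes it as a JSON-escaped string; hence the "kinda api-compliant" in the commit title.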

View file

@@ -12,7 +12,9 @@ function cleanup() {
trap cleanup EXIT
echo "# Starting the server"
python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
python -m examples.openai --model ~/AI/Models/functionary-medium-v2.2.q4_0.gguf &
# python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
# python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf &
SERVER_PID=$!
@@ -73,8 +75,8 @@ curl http://localhost:8080/v1/chat/completions \
}
}],
"messages": [
{"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
{"role": "user", "content": "what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
{"role": "user", "content": "I live in the UK. what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
]
}'
# {"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},

View file: examples/openai/ts_converter.py

@@ -1,5 +1,5 @@
from typing import Any, List, Set, Tuple, Union
from jsonargparse import CLI
import json
class SchemaToTypeScriptConverter:
# TODO: comments for arguments!
@@ -14,11 +14,14 @@ class SchemaToTypeScriptConverter:
# // where to get weather.
# location: string,
# }) => any;
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
return "{" + ', '.join(
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], additional_properties: Union[bool, Any]):
return "{" + ', '.join([
f'{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
for prop_name, prop_schema in properties
) + "}"
] + (
[f"[key: string]: {self.visit(additional_properties)}"]
if additional_properties is not None else []
)) + "}"
def visit(self, schema: dict):
def print_constant(v):
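
For illustration, a standalone sketch of the updated _build_object_rule (visit() stubbed to a toy type mapper; the method body is copied from the diff above):

from typing import Any, List, Set, Tuple, Union

class MiniTsConverter:
    def visit(self, schema: dict) -> str:
        # Toy stand-in for the real visit(): map two primitive types, default to any.
        return {"string": "string", "number": "number"}.get(schema.get("type"), "any")

    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], additional_properties: Union[bool, Any]):
        return "{" + ', '.join([
            f'{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
            for prop_name, prop_schema in properties
        ] + (
            [f"[key: string]: {self.visit(additional_properties)}"]
            if additional_properties is not None else []
        )) + "}"

print(MiniTsConverter()._build_object_rule(
    properties=[("location", {"type": "string"}), ("num_days", {"type": "number"})],
    required={"location"},
    additional_properties=None))
# -> {location: string, num_days?: number}

One subtlety: only additional_properties=None suppresses the index signature; a schema with additionalProperties: false would still emit one, since the guard is "is not None" (and visit() would then receive a bool).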