server.py: kinda api-compliant output, disabled grammar

ochafik 2024-03-27 01:50:26 +00:00
parent 8afd4de17b
commit aa9605c751
5 changed files with 136 additions and 33 deletions

View file: examples/openai/api.py

@@ -1,10 +1,15 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Literal, Optional, Union
 from pydantic import BaseModel, Json

-class ToolCall(BaseModel):
+class FunctionCall(BaseModel):
     name: str
     arguments: Dict[str, Any]

+class ToolCall(BaseModel):
+    id: Optional[str] = None
+    type: Literal["function"] = "function"
+    function: FunctionCall
+
 class Message(BaseModel):
     role: str
     content: Optional[str]
@@ -30,3 +35,23 @@ class ChatCompletionRequest(BaseModel):
     response_format: Optional[ResponseFormat] = None
     temperature: float = 1.0
     stream: bool = False
+
+class Choice(BaseModel):
+    index: int
+    message: Message
+    logprobs: Optional[Json] = None
+    finish_reason: Union[Literal["stop"], Literal["tool_calls"]]
+
+class Usage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: Literal["chat.completion"]
+    created: int
+    model: str
+    choices: list[Choice]
+    usage: Usage
+    system_fingerprint: str
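
For reference, these models mirror OpenAI's chat.completion response shape. A minimal sketch of how an assistant tool-call message would serialize under them (all values illustrative, not from a real run):

msg = Message(
    role="assistant",
    content=None,
    tool_calls=[ToolCall(
        id="call_42",  # hypothetical id
        function=FunctionCall(name="get_weather", arguments={"location": "Glasgow"}))])
response = ChatCompletionResponse(
    id="chatcmpl-123",
    object="chat.completion",
    created=1711500000,
    model="functionary-medium-v2.2",
    choices=[Choice(index=0, message=msg, finish_reason="tool_calls")],
    usage=Usage(prompt_tokens=100, completion_tokens=20, total_tokens=120),
    system_fingerprint="...")
print(response.model_dump_json(indent=2))  # pydantic v2, as used elsewhere in this commit

Note that FunctionCall.arguments is a Dict here, while OpenAI's API returns arguments as a JSON-encoded string; that gap is presumably the "kinda api-compliant" of the commit title.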

View file: examples/openai/prompting.py

@@ -2,13 +2,14 @@ from enum import Enum
 import jinja2
 import json
 from pathlib import Path
-import sys
+import random
 import re
+import sys
 from typing import Optional, Tuple, Callable
 from typeguard import typechecked

 from examples.json_schema_to_grammar import SchemaConverter
-from examples.openai.api import Tool, Message
+from examples.openai.api import Tool, Message, FunctionCall, ToolCall
 from examples.openai.gguf_kvs import GGUFKeyValues, Keys
 from examples.openai.ts_converter import SchemaToTypeScriptConverter
@@ -42,7 +43,7 @@ class ChatFormat:
             (i, m) = system_message
             return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
         else:
-            return [Message(role="system", content=system_prompt)] + messages
+            return [system_prompt] + messages

     @staticmethod
     def from_gguf(metadata: GGUFKeyValues):
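
This hunk merges a tool-usage system prompt into any existing system message (system_prompt is now a Message, so the new branch can prepend it directly). The surrounding lines are cut off here; a self-contained sketch of the behavior the visible code implies (the enumerate-based lookup is an assumption):

def add_system_prompt(messages: list[Message], system_prompt: Message) -> list[Message]:
    # Find an existing system message, if any (assumed lookup).
    system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
    if system_message is not None:
        (i, m) = system_message
        # Append the extra instructions to the existing system message.
        return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
    else:
        # No system message yet: prepend the prompt Message as-is.
        return [system_prompt] + messages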
@@ -69,7 +70,7 @@ class ChatFormat:
                 i += 1
             # print(f'new_messages={json.dumps(new_messages, indent=2)}')
             messages = new_messages
-            print(f'messages={messages}')
+            # print(f'messages={messages}')

         return self.template.render(
             messages=messages,
@@ -175,10 +176,15 @@ def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
         ToolsPromptStyle.TOOLS_HERMES_2_PRO,
     )

-_tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)
+_tool_call_re = re.compile(
+    '<tool_call>(.*?)</tool_call>', re.DOTALL)
+_recipient_content_re = re.compile(r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)
+
+def gen_callid():
+    return f'call_{random.randint(0, 1000000)}'

 @typechecked
-def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[Message]]]:
+def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[list[Message]]]]:

     converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
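
The Hermes-style parser further down relies on re.split keeping the captured tool-call JSON between the text segments. A quick standalone illustration of that behavior (input made up):

import re
_tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)
s = 'Let me check.<tool_call>{"name": "get_weather", "arguments": {"location": "Glasgow"}}</tool_call>'
print(_tool_call_re.split(s))
# ['Let me check.', '{"name": "get_weather", "arguments": {"location": "Glasgow"}}', '']
# Even indices are plain text, odd indices are tool-call payloads.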
@@ -191,6 +197,13 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
     assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
     [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)

+    def strip_suffix(s: str) -> str:
+        if s.endswith(suffix):
+            return s[:-len(suffix)]
+        else:
+            print(f"Expected suffix ({suffix}) not found: {s}")
+            return s
+
     if tools:
         if _outputs_tool_call_tags(chat_format.tool_style):
             tool_rules = [
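
prefix and suffix come from rendering the chat template twice, once around an empty assistant turn and once around a known marker, then diffing the results; strip_suffix can then trim the template's turn terminator (e.g. an end token) off raw completions. A sketch of that trick under assumed helper names (the actual delimiter value is defined elsewhere in this file):

def find_affixes(chat_format: ChatFormat, user_msg: Message) -> Tuple[str, str]:
    delimiter = '<%$[SAMPLE]$%>'  # assumption: any string the template won't emit on its own
    empty_prompt = chat_format.render([user_msg], add_generation_prompt=True)
    planted_prompt = chat_format.render(
        [user_msg, Message(role="assistant", content=delimiter)],
        add_generation_prompt=False)
    assert planted_prompt.startswith(empty_prompt)
    # Whatever surrounds the delimiter is what the template wraps replies in,
    # e.g. ('', '<|im_end|>') for a ChatML-style template.
    [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
    return (prefix, suffix)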
@@ -221,6 +234,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
             @typechecked
             def parse(s: str) -> Optional[Message]:
+                s = strip_suffix(s)
+
                 # ls = s.lstrip()
                 parts = _tool_call_re.split(s)
                 if len(parts) == 1:
@@ -232,11 +247,22 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
                         if i % 2 == 0:
                             content.append(part)
                         else:
-                            tool_calls.append(json.loads(part))
+                            tool_calls.append(
+                                ToolCall(
+                                    id=gen_callid(),
+                                    function=FunctionCall(**json.loads(part))))

                     content = ''.join(content).strip()
                     return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls)

+                # if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
+                #     if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
+                #         tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
+                #         return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
+                #     return None
+                # else:
+                #     return Message(role="assistant", content=s)
+
             return (converter.format_grammar(), parse)

         elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
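
End to end, the Hermes branch turns raw model text into an OpenAI-style message. A self-contained sketch of that path (the suffix value is an assumption; the real one is discovered from the template as described above):

import json, re
_tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)
suffix = '<|im_end|>'  # assumption: ChatML-style turn terminator
raw = 'Sure!<tool_call>{"name": "get_weather", "arguments": {"location": "Glasgow"}}</tool_call><|im_end|>'
s = raw[:-len(suffix)] if raw.endswith(suffix) else raw
parts = _tool_call_re.split(s)
content = ''.join(p for i, p in enumerate(parts) if i % 2 == 0).strip()
tool_calls = [json.loads(p) for i, p in enumerate(parts) if i % 2 == 1]
print(content)     # Sure!
print(tool_calls)  # [{'name': 'get_weather', 'arguments': {'location': 'Glasgow'}}]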
@@ -256,7 +282,30 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
         @typechecked
         def parse(s: str) -> Optional[Message]:
-            raise NotImplementedError(f'TODO: parse tool_style {chat_format.tool_style}: {s}')
+            s = strip_suffix(s)
+
+            parts = _recipient_content_re.split(s)
+            if len(parts) == 1:
+                return Message(role="assistant", content=s)
+            else:
+                text_content = []
+                tool_calls: list[ToolCall] = []
+                for i in range((len(parts) - 1) // 3):
+                    assert parts[i * 3].strip() == '', f'Unexpected content before tool call: {parts[i * 3]}'
+                    recipient = parts[i * 3 + 1].strip()
+                    content = parts[i * 3 + 2]
+                    if recipient == 'all':
+                        text_content.append(content)
+                    else:
+                        tool_calls.append(
+                            ToolCall(
+                                id=gen_callid(),
+                                function=FunctionCall(name=recipient, arguments=json.loads(content))))
+
+                assert parts[-1].strip() == '', f'Unexpected content after tool calls: {parts[-1]}'
+
+                content = '\n'.join(text_content).strip()
+                return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls if tool_calls else None)

         return (converter.format_grammar(), parse)
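
_recipient_content_re does the heavy lifting for Functionary v2 output: each match yields a recipient (a function name, or 'all' for text addressed to the user) and its <|content|> body. An illustration of the split the loop above consumes, on an assumed single-call completion:

import re
_recipient_content_re = re.compile(r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)
s = 'get_weather\n<|content|>{"location": "Glasgow"}<|stop|>'
print(_recipient_content_re.split(s))
# ['', 'get_weather', '{"location": "Glasgow"}', '']
# Groups of three: (text before, recipient, content); the loop asserts that
# the "text before" and the trailing remainder are empty.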
@@ -265,8 +314,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
         @typechecked
         def parse(s: str) -> Optional[Message]:
-            if response_rule.endswith(suffix):
-                return Message(role="assistant", content=s[:-len(suffix)])
+            s = strip_suffix(s)
+            return Message(role="assistant", content=s)

         return (converter.format_grammar(), parse)
@@ -275,9 +324,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
         @typechecked
         def parse(s: str) -> Optional[Message]:
-            if s.endswith(suffix):
-                return Message(role="assistant", content=s[:-len(suffix)])
-            return None
+            s = strip_suffix(s)
+            return Message(role="assistant", content=s)

         return (None, parse)

View file: examples/openai/server.py

@@ -3,22 +3,27 @@
 import json, sys, subprocess, atexit
 from pathlib import Path
+import time

 sys.path.insert(0, str(Path(__file__).parent.parent.parent))

 from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
 from examples.openai.gguf_kvs import GGUFKeyValues, Keys
-from examples.openai.api import Message, ChatCompletionRequest
+from examples.openai.api import ChatCompletionResponse, Choice, Message, ChatCompletionRequest, Usage
 from examples.openai.prompting import ChatFormat, make_grammar, make_tools_prompt

 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse
 import httpx
+import random
 from starlette.responses import StreamingResponse
 from typing import Annotated, Optional
 import typer
 from typeguard import typechecked

+def generate_id(prefix):
+    return f"{prefix}{random.randint(0, 1 << 32)}"
+
 def main(
     model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
     # model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
@@ -68,17 +73,13 @@ def main(
        (grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)

-       if chat_format.strict_user_assistant_alternation:
-           print("TODO: merge system messages into user messages")
-           # new_messages = []
-
        # TODO: Test whether the template supports formatting tool_calls

        prompt = chat_format.render(messages, add_generation_prompt=True)

        print(json.dumps(dict(
            stream=chat_request.stream,
            prompt=prompt,
-           grammar=grammar,
+           # grammar=grammar,
        ), indent=2))

        async with httpx.AsyncClient() as client:
            response = await client.post(
@@ -87,7 +88,7 @@ def main(
                    prompt=prompt,
                    stream=chat_request.stream,
                    n_predict=300,
-                   grammar=grammar,
+                   # grammar=grammar,
                ).model_dump(),
                headers=headers,
                timeout=None)
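
For context, the proxied call is a plain llama.cpp /completion request, with the grammar constraint commented out by this commit (hence "disabled grammar" in the title). A minimal sketch of the same call (host and port are assumptions; the real target is configured elsewhere in this file):

import asyncio, httpx

async def complete(prompt: str) -> dict:
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "http://localhost:8081/completion",  # assumed llama.cpp server address
            json={"prompt": prompt, "stream": False, "n_predict": 300},
            timeout=None)
        return response.json()  # has 'content' and 'timings' on success

print(asyncio.run(complete("Hello"))["content"])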
@@ -98,11 +99,35 @@ def main(
                return StreamingResponse(generate_chunks(response), media_type="text/event-stream")
            else:
                result = response.json()
+               if 'content' not in result:
+                   # print(json.dumps(result, indent=2))
+                   return JSONResponse(result)
+
                print(json.dumps(result, indent=2))
-               # print(json.dumps(result.get('content'), indent=2))
                message = parser(result["content"])
-               assert message is not None, f"Failed to parse response: {response.text}"
-               return JSONResponse(message.model_dump())
-               # return JSONResponse(response.json())
+               assert message is not None, f"Failed to parse response:\n{response.text}\n\n"
+
+               prompt_tokens=result['timings']['prompt_n']
+               completion_tokens=result['timings']['predicted_n']
+               return JSONResponse(ChatCompletionResponse(
+                   id=generate_id('chatcmpl-'),
+                   object="chat.completion",
+                   created=int(time.time()),
+                   model=chat_request.model,
+                   choices=[Choice(
+                       index=0,
+                       message=message,
+                       finish_reason="stop" if message.tool_calls is None else "tool_calls",
+                   )],
+                   usage=Usage(
+                       prompt_tokens=prompt_tokens,
+                       completion_tokens=completion_tokens,
+                       total_tokens=prompt_tokens + completion_tokens,
+                   ),
+                   system_fingerprint='...'
+               ).model_dump())

    async def generate_chunks(response):
        async for chunk in response.aiter_bytes():
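
The usage numbers are lifted from llama.cpp's timings block. A sketch of the slice of result this handler depends on (field values illustrative):

result = {
    "content": 'Sure!<tool_call>{"name": "get_weather", "arguments": {"location": "Glasgow"}}</tool_call>',
    "timings": {"prompt_n": 100, "predicted_n": 20},
}
prompt_tokens = result['timings']['prompt_n']         # -> usage.prompt_tokens
completion_tokens = result['timings']['predicted_n']  # -> usage.completion_tokens
print(prompt_tokens + completion_tokens)              # -> usage.total_tokens, 120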

View file

@@ -12,7 +12,9 @@ function cleanup() {
 trap cleanup EXIT

 echo "# Starting the server"
-python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
+python -m examples.openai --model ~/AI/Models/functionary-medium-v2.2.q4_0.gguf &
+# python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
 # python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf &
 SERVER_PID=$!
@@ -73,8 +75,8 @@ curl http://localhost:8080/v1/chat/completions \
       }
     }],
     "messages": [
-      {"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
-      {"role": "user", "content": "what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
+      {"role": "user", "content": "I live in the UK. what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
     ]
   }'
+# {"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},

View file: examples/openai/ts_converter.py

@@ -1,5 +1,5 @@
 from typing import Any, List, Set, Tuple, Union
-from jsonargparse import CLI
+import json

 class SchemaToTypeScriptConverter:
     # TODO: comments for arguments!
@@ -14,11 +14,14 @@ class SchemaToTypeScriptConverter:
     #    // where to get weather.
     #    location: string,
     # }) => any;
-    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
-        return "{" + ', '.join(
+    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], additional_properties: Union[bool, Any]):
+        return "{" + ', '.join([
             f'{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
             for prop_name, prop_schema in properties
-        ) + "}"
+        ] + (
+            [f"[key: string]: {self.visit(additional_properties)}"]
+            if additional_properties is not None else []
+        )) + "}"

     def visit(self, schema: dict):
         def print_constant(v):
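
The new bracketed form lets _build_object_rule append a TypeScript index signature when additionalProperties carries a schema. A standalone sketch of the same logic (visit is stubbed here; in the class it recurses into the converter):

from typing import Any, List, Optional, Set, Tuple

def build_object_type(properties: List[Tuple[str, Any]], required: Set[str],
                      additional_properties: Optional[dict],
                      visit=lambda schema: schema.get("type", "any")) -> str:
    # Required props appear bare, optional ones get '?', and an index
    # signature is appended only when additional_properties is given.
    return "{" + ', '.join([
        f'{name}{"" if name in required else "?"}: {visit(schema)}'
        for name, schema in properties
    ] + (
        [f"[key: string]: {visit(additional_properties)}"]
        if additional_properties is not None else []
    )) + "}"

print(build_object_type(
    [("location", {"type": "string"}), ("days", {"type": "number"})],
    {"location"}, {"type": "string"}))
# {location: string, days?: number, [key: string]: string}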