server.py: kinda api-compliant output, disabled grammar

parent 8afd4de17b
commit aa9605c751

5 changed files with 136 additions and 33 deletions
examples/openai/api.py

@@ -1,10 +1,15 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Literal, Optional, Union
 
 from pydantic import BaseModel, Json
 
-class ToolCall(BaseModel):
+class FunctionCall(BaseModel):
     name: str
     arguments: Dict[str, Any]
 
+class ToolCall(BaseModel):
+    id: Optional[str] = None
+    type: Literal["function"] = "function"
+    function: FunctionCall
+
 class Message(BaseModel):
     role: str
     content: Optional[str]
@@ -30,3 +35,23 @@ class ChatCompletionRequest(BaseModel):
     response_format: Optional[ResponseFormat] = None
     temperature: float = 1.0
     stream: bool = False
+
+class Choice(BaseModel):
+    index: int
+    message: Message
+    logprobs: Optional[Json] = None
+    finish_reason: Union[Literal["stop"], Literal["tool_calls"]]
+
+class Usage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: Literal["chat.completion"]
+    created: int
+    model: str
+    choices: list[Choice]
+    usage: Usage
+    system_fingerprint: str
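The FunctionCall/ToolCall split above mirrors OpenAI's chat-completions schema: each tool call wraps a typed function payload. A minimal sketch of how the new models nest (assuming pydantic v2, whose model_dump the server code below relies on; the id value is hypothetical, standing in for the server-generated one):

    from examples.openai.api import FunctionCall, ToolCall

    call = ToolCall(
        id="call_1234",  # hypothetical; the server mints one per call
        function=FunctionCall(name="get_weather", arguments={"location": "Glasgow"}),
    )
    print(call.model_dump())
    # {'id': 'call_1234', 'type': 'function',
    #  'function': {'name': 'get_weather', 'arguments': {'location': 'Glasgow'}}}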
examples/openai/prompting.py

@@ -2,13 +2,14 @@ from enum import Enum
 import jinja2
 import json
 from pathlib import Path
-import sys
+import random
+import re
+import sys
 from typing import Optional, Tuple, Callable
 from typeguard import typechecked
 
 from examples.json_schema_to_grammar import SchemaConverter
-from examples.openai.api import Tool, Message
+from examples.openai.api import Tool, Message, FunctionCall, ToolCall
 from examples.openai.gguf_kvs import GGUFKeyValues, Keys
 from examples.openai.ts_converter import SchemaToTypeScriptConverter
 
@@ -42,7 +43,7 @@ class ChatFormat:
             (i, m) = system_message
             return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
         else:
-            return [Message(role="system", content=system_prompt)] + messages
+            return [system_prompt] + messages
 
     @staticmethod
     def from_gguf(metadata: GGUFKeyValues):
@@ -69,7 +70,7 @@ class ChatFormat:
                 i += 1
             # print(f'new_messages={json.dumps(new_messages, indent=2)}')
             messages = new_messages
-        print(f'messages={messages}')
+        # print(f'messages={messages}')
 
         return self.template.render(
             messages=messages,
@@ -175,10 +176,15 @@ def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
         ToolsPromptStyle.TOOLS_HERMES_2_PRO,
     )
 
-_tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)
+_tool_call_re = re.compile(
+    '<tool_call>(.*?)</tool_call>', re.DOTALL)
+_recipient_content_re = re.compile(r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)
+
+def gen_callid():
+    return f'call_{random.randint(0, 1000000)}'
 
 @typechecked
-def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[Message]]]:
+def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[list[Message]]]]:
 
     converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
 
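Both parsers below lean on re.split with a capturing group: splitting on _tool_call_re interleaves plain text (even indices) with the captured tool-call JSON (odd indices), which is exactly the alternation the Hermes-style parse walks. A quick illustration on a made-up completion:

    import re

    _tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)
    out = 'Let me check.\n<tool_call>{"name": "get_weather", "arguments": {"location": "Glasgow"}}</tool_call>'
    parts = _tool_call_re.split(out)
    # ['Let me check.\n',
    #  '{"name": "get_weather", "arguments": {"location": "Glasgow"}}',
    #  '']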
@@ -191,6 +197,13 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
     assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
     [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
 
+    def strip_suffix(s: str) -> str:
+        if s.endswith(suffix):
+            return s[:-len(suffix)]
+        else:
+            print(f"Expected suffix ({suffix}) not found: {s}")
+            return s
+
     if tools:
         if _outputs_tool_call_tags(chat_format.tool_style):
             tool_rules = [
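The strip_suffix helper added above trims the end-of-turn marker that the template appends (the suffix computed from the planted prompt) before any parsing. A tiny sketch with a hypothetical suffix:

    suffix = '</s>'  # hypothetical end-of-turn marker

    def strip_suffix(s: str) -> str:
        return s[:-len(suffix)] if s.endswith(suffix) else s

    assert strip_suffix('Hello</s>') == 'Hello'
    assert strip_suffix('Hello') == 'Hello'  # the real code also prints a warning here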
@@ -221,6 +234,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
 
             @typechecked
             def parse(s: str) -> Optional[Message]:
+                s = strip_suffix(s)
+
                 # ls = s.lstrip()
                 parts = _tool_call_re.split(s)
                 if len(parts) == 1:
@@ -232,10 +247,21 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
                         if i % 2 == 0:
                             content.append(part)
                         else:
-                            tool_calls.append(json.loads(part))
+                            tool_calls.append(
+                                ToolCall(
+                                    id=gen_callid(),
+                                    function=FunctionCall(**json.loads(part))))
 
                     content = ''.join(content).strip()
                     return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls)
 
+                # if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
+                #     if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
+                #         tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
+                #         return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
+                #     return None
+                # else:
+                #     return Message(role="assistant", content=s)
+
             return (converter.format_grammar(), parse)
@@ -256,7 +282,30 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
 
             @typechecked
             def parse(s: str) -> Optional[Message]:
-                raise NotImplementedError(f'TODO: parse tool_style {chat_format.tool_style}: {s}')
+                s = strip_suffix(s)
+
+                parts = _recipient_content_re.split(s)
+                if len(parts) == 1:
+                    return Message(role="assistant", content=s)
+                else:
+                    text_content = []
+                    tool_calls: list[ToolCall] = []
+                    for i in range((len(parts) - 1) // 3):
+                        assert parts[i * 3].strip() == '', f'Unexpected content before tool call: {parts[i * 3]}'
+                        recipient = parts[i * 3 + 1].strip()
+                        content = parts[i * 3 + 2]
+                        if recipient == 'all':
+                            text_content.append(content)
+                        else:
+                            tool_calls.append(
+                                ToolCall(
+                                    id=gen_callid(),
+                                    function=FunctionCall(name=recipient, arguments=json.loads(content))))
+
+                    assert parts[-1].strip() == '', f'Unexpected content after tool calls: {parts[-1]}'
+
+                    content = '\n'.join(text_content).strip()
+                    return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls if tool_calls else None)
 
             return (converter.format_grammar(), parse)
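The functionary-style parse above walks _recipient_content_re.split(s) in strides of three: a gap (asserted empty), the recipient, and its content; recipient 'all' is user-facing text, anything else becomes a tool call. A sketch on a made-up transcript in that token format:

    parts = _recipient_content_re.split(
        'all\n<|content|>Let me check.'
        '<|stop|><|from|> assistant\n<|recipient|>get_weather\n<|content|>{"location": "Glasgow"}')
    # ['', 'all', 'Let me check.', '', 'get_weather', '{"location": "Glasgow"}', '']
    # -> content 'Let me check.' plus one get_weather tool call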
@@ -265,8 +314,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
 
         @typechecked
         def parse(s: str) -> Optional[Message]:
-            if response_rule.endswith(suffix):
-                return Message(role="assistant", content=s[:-len(suffix)])
+            s = strip_suffix(s)
+            return Message(role="assistant", content=s)
 
         return (converter.format_grammar(), parse)
@@ -275,9 +324,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
 
         @typechecked
         def parse(s: str) -> Optional[Message]:
-            if s.endswith(suffix):
-                return Message(role="assistant", content=s[:-len(suffix)])
-            return None
+            s = strip_suffix(s)
+            return Message(role="assistant", content=s)
 
         return (None, parse)

examples/openai/server.py

@@ -3,22 +3,27 @@
 import json, sys, subprocess, atexit
 from pathlib import Path
+import time
 
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 
 from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
 from examples.openai.gguf_kvs import GGUFKeyValues, Keys
-from examples.openai.api import Message, ChatCompletionRequest
+from examples.openai.api import ChatCompletionResponse, Choice, Message, ChatCompletionRequest, Usage
 from examples.openai.prompting import ChatFormat, make_grammar, make_tools_prompt
 
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse
 import httpx
+import random
 from starlette.responses import StreamingResponse
 from typing import Annotated, Optional
 import typer
 from typeguard import typechecked
 
+def generate_id(prefix):
+    return f"{prefix}{random.randint(0, 1 << 32)}"
+
 def main(
     model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
     # model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
@@ -68,17 +73,13 @@ def main(
         (grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
 
         if chat_format.strict_user_assistant_alternation:
             print("TODO: merge system messages into user messages")
             # new_messages = []
 
         # TODO: Test whether the template supports formatting tool_calls
 
         prompt = chat_format.render(messages, add_generation_prompt=True)
         print(json.dumps(dict(
             stream=chat_request.stream,
             prompt=prompt,
-            grammar=grammar,
+            # grammar=grammar,
         ), indent=2))
         async with httpx.AsyncClient() as client:
             response = await client.post(
@@ -87,7 +88,7 @@ def main(
                     prompt=prompt,
                     stream=chat_request.stream,
                     n_predict=300,
-                    grammar=grammar,
+                    # grammar=grammar,
                 ).model_dump(),
                 headers=headers,
                 timeout=None)
@@ -98,11 +99,35 @@ def main(
                 return StreamingResponse(generate_chunks(response), media_type="text/event-stream")
             else:
                 result = response.json()
                 if 'content' not in result:
+                    # print(json.dumps(result, indent=2))
                     return JSONResponse(result)
 
+                print(json.dumps(result, indent=2))
+                # print(json.dumps(result.get('content'), indent=2))
                 message = parser(result["content"])
-                assert message is not None, f"Failed to parse response: {response.text}"
-                return JSONResponse(message.model_dump())
-                # return JSONResponse(response.json())
+                assert message is not None, f"Failed to parse response:\n{response.text}\n\n"
+
+                prompt_tokens=result['timings']['prompt_n']
+                completion_tokens=result['timings']['predicted_n']
+                return JSONResponse(ChatCompletionResponse(
+                    id=generate_id('chatcmpl-'),
+                    object="chat.completion",
+                    created=int(time.time()),
+                    model=chat_request.model,
+                    choices=[Choice(
+                        index=0,
+                        message=message,
+                        finish_reason="stop" if message.tool_calls is None else "tool_calls",
+                    )],
+                    usage=Usage(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens,
+                    ),
+                    system_fingerprint='...'
+                ).model_dump())
 
     async def generate_chunks(response):
         async for chunk in response.aiter_bytes():

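With this change the non-streaming path wraps the llama.cpp result in an OpenAI-style chat.completion body instead of returning the bare parsed message. Roughly what a client now receives (illustrative values; system_fingerprint is still the '...' placeholder):

    {
      "id": "chatcmpl-3094863119",
      "object": "chat.completion",
      "created": 1711578416,
      "model": "functionary-medium-v2.2",
      "choices": [{
        "index": 0,
        "message": {
          "role": "assistant",
          "content": null,
          "tool_calls": [{
            "id": "call_531873",
            "type": "function",
            "function": {"name": "get_weather", "arguments": {"location": "Glasgow"}}
          }]
        },
        "logprobs": null,
        "finish_reason": "tool_calls"
      }],
      "usage": {"prompt_tokens": 291, "completion_tokens": 38, "total_tokens": 329},
      "system_fingerprint": "..."
    }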
@@ -12,7 +12,9 @@ function cleanup() {
 trap cleanup EXIT
 
 echo "# Starting the server"
-python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
+python -m examples.openai --model ~/AI/Models/functionary-medium-v2.2.q4_0.gguf &
+# python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
+# python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf &
 SERVER_PID=$!
@@ -73,8 +75,8 @@ curl http://localhost:8080/v1/chat/completions \
       }
     }],
     "messages": [
      {"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
-      {"role": "user", "content": "what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
+      {"role": "user", "content": "I live in the UK. what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
     ]
   }'
 
+# {"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},

examples/openai/ts_converter.py

@@ -1,5 +1,5 @@
 from typing import Any, List, Set, Tuple, Union
-from jsonargparse import CLI
+import json
 
 class SchemaToTypeScriptConverter:
     # TODO: comments for arguments!
@@ -14,11 +14,14 @@ class SchemaToTypeScriptConverter:
     #     // where to get weather.
     #     location: string,
     # }) => any;
-    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
-        return "{" + ', '.join(
+    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], additional_properties: Union[bool, Any]):
+        return "{" + ', '.join([
             f'{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
             for prop_name, prop_schema in properties
-        ) + "}"
+        ] + (
+            [f"[key: string]: {self.visit(additional_properties)}"]
+            if additional_properties is not None else []
+        )) + "}"
 
     def visit(self, schema: dict):
         def print_constant(v):
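For reference, a sketch of what the new index-signature branch of _build_object_rule emits (assuming visit maps {"type": "string"} to string and so on):

    # properties=[("location", {"type": "string"})], required={"location"},
    # additional_properties={"type": "number"}
    #   -> {location: string, [key: string]: number}
    # additional_properties=None
    #   -> {location: string}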