server.py: kinda api-compliant output, disabled grammar
parent 8afd4de17b
commit aa9605c751

5 changed files with 136 additions and 33 deletions
examples/openai/api.py

@@ -1,10 +1,15 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Literal, Optional, Union
 from pydantic import BaseModel, Json
 
-class ToolCall(BaseModel):
+class FunctionCall(BaseModel):
     name: str
     arguments: Dict[str, Any]
 
+class ToolCall(BaseModel):
+    id: Optional[str] = None
+    type: Literal["function"] = "function"
+    function: FunctionCall
+
 class Message(BaseModel):
     role: str
     content: Optional[str]
@@ -30,3 +35,23 @@ class ChatCompletionRequest(BaseModel):
     response_format: Optional[ResponseFormat] = None
     temperature: float = 1.0
     stream: bool = False
+
+class Choice(BaseModel):
+    index: int
+    message: Message
+    logprobs: Optional[Json] = None
+    finish_reason: Union[Literal["stop"], Literal["tool_calls"]]
+
+class Usage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: Literal["chat.completion"]
+    created: int
+    model: str
+    choices: list[Choice]
+    usage: Usage
+    system_fingerprint: str
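These additions mirror the OpenAI chat-completion response schema. As an illustrative sketch (not part of the commit), here is how they serialize; it assumes the repo root is on sys.path, as server.py arranges, and that Message also carries the tool_calls field used by the parsers later in this commit:

# Illustrative values only; id/model strings are hypothetical.
from examples.openai.api import (ChatCompletionResponse, Choice, FunctionCall,
                                 Message, ToolCall, Usage)

response = ChatCompletionResponse(
    id="chatcmpl-123",
    object="chat.completion",
    created=1711000000,
    model="example-model",
    choices=[Choice(
        index=0,
        message=Message(
            role="assistant",
            content=None,
            tool_calls=[ToolCall(
                id="call_42",
                function=FunctionCall(name="get_current_weather",
                                      arguments={"location": "Glasgow"}))]),
        finish_reason="tool_calls")],
    usage=Usage(prompt_tokens=100, completion_tokens=35, total_tokens=135),
    system_fingerprint="...")
print(response.model_dump_json(indent=2))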
examples/openai/prompting.py

@@ -2,13 +2,14 @@ from enum import Enum
 import jinja2
 import json
 from pathlib import Path
-import sys
+import random
 import re
+import sys
 from typing import Optional, Tuple, Callable
 from typeguard import typechecked
 
 from examples.json_schema_to_grammar import SchemaConverter
-from examples.openai.api import Tool, Message
+from examples.openai.api import Tool, Message, FunctionCall, ToolCall
 from examples.openai.gguf_kvs import GGUFKeyValues, Keys
 from examples.openai.ts_converter import SchemaToTypeScriptConverter
 
@@ -42,7 +43,7 @@ class ChatFormat:
             (i, m) = system_message
             return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
         else:
-            return [Message(role="system", content=system_prompt)] + messages
+            return [system_prompt] + messages
 
     @staticmethod
     def from_gguf(metadata: GGUFKeyValues):
@@ -69,7 +70,7 @@ class ChatFormat:
                 i += 1
             # print(f'new_messages={json.dumps(new_messages, indent=2)}')
             messages = new_messages
-        print(f'messages={messages}')
+        # print(f'messages={messages}')
 
         return self.template.render(
             messages=messages,
@@ -175,10 +176,15 @@ def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
         ToolsPromptStyle.TOOLS_HERMES_2_PRO,
     )
 
-_tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)
+_tool_call_re = re.compile(
+    '<tool_call>(.*?)</tool_call>', re.DOTALL)
+_recipient_content_re = re.compile(r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)
 
+def gen_callid():
+    return f'call_{random.randint(0, 1000000)}'
+
 @typechecked
-def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[Message]]]:
+def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[list[Message]]]]:
 
     converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
 
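The parsers below rely on re.split's behavior with a capturing group: even indices of the result are plain text, odd indices are the captured tool-call JSON payloads. A quick standalone check (the sample completion text is an assumption):

import re

_tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)

s = 'Let me check.<tool_call>{"name": "get_weather", "arguments": {"location": "Glasgow"}}</tool_call>'
parts = _tool_call_re.split(s)
# -> ['Let me check.', '{"name": "get_weather", "arguments": {"location": "Glasgow"}}', '']
print(parts)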
@@ -191,6 +197,13 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
     assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
     [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
 
+    def strip_suffix(s: str) -> str:
+        if s.endswith(suffix):
+            return s[:-len(suffix)]
+        else:
+            print(f"Expected suffix ({suffix}) not found: {s}")
+            return s
+
     if tools:
         if _outputs_tool_call_tags(chat_format.tool_style):
             tool_rules = [
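For context, the prefix/suffix pair that strip_suffix relies on comes from rendering the chat template once with an empty assistant message and once with a planted delimiter string, then splitting on that delimiter. A minimal sketch with assumed template values (the real values depend on the model's GGUF template):

# Hypothetical render results for a ChatML-style template; not from the commit.
empty_prompt = "<|im_start|>assistant\n"
planted_prompt = "<|im_start|>assistant\n__MARKER__<|im_end|>\n"
delimiter = "__MARKER__"

assert planted_prompt.startswith(empty_prompt)
[prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
print(repr(prefix), repr(suffix))  # '' '<|im_end|>\n'; suffix is what strip_suffix removes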
@@ -221,6 +234,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
 
         @typechecked
         def parse(s: str) -> Optional[Message]:
+            s = strip_suffix(s)
+
             # ls = s.lstrip()
             parts = _tool_call_re.split(s)
             if len(parts) == 1:
@@ -232,10 +247,21 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
                     if i % 2 == 0:
                         content.append(part)
                     else:
-                        tool_calls.append(json.loads(part))
+                        tool_calls.append(
+                            ToolCall(
+                                id=gen_callid(),
+                                function=FunctionCall(**json.loads(part))))
 
                 content = ''.join(content).strip()
                 return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls)
 
+            # if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
+            #     if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
+            #         tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
+            #         return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
+            #     return None
+            # else:
+            #     return Message(role="assistant", content=s)
+
         return (converter.format_grammar(), parse)
 
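A standalone sketch of the Hermes-style parse path added above, with plain dicts standing in for the ToolCall/FunctionCall models so it runs without the repo:

import json
import random
import re

_tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)

def gen_callid():
    return f'call_{random.randint(0, 1000000)}'

def parse_hermes(s: str) -> dict:
    parts = _tool_call_re.split(s)
    if len(parts) == 1:
        return {"role": "assistant", "content": s}
    content, tool_calls = [], []
    for i, part in enumerate(parts):
        if i % 2 == 0:
            content.append(part)   # even indices: plain text
        else:                      # odd indices: captured tool-call JSON
            tool_calls.append({"id": gen_callid(), "type": "function",
                               "function": json.loads(part)})
    text = ''.join(content).strip()
    return {"role": "assistant", "content": text or None, "tool_calls": tool_calls}

print(parse_hermes('<tool_call>{"name": "get_weather", "arguments": {"location": "Glasgow"}}</tool_call>'))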
@@ -256,7 +282,30 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
 
         @typechecked
         def parse(s: str) -> Optional[Message]:
-            raise NotImplementedError(f'TODO: parse tool_style {chat_format.tool_style}: {s}')
+            s = strip_suffix(s)
+
+            parts = _recipient_content_re.split(s)
+            if len(parts) == 1:
+                return Message(role="assistant", content=s)
+            else:
+                text_content = []
+                tool_calls: list[ToolCall] = []
+                for i in range((len(parts) - 1) // 3):
+                    assert parts[i * 3].strip() == '', f'Unexpected content before tool call: {parts[i * 3]}'
+                    recipient = parts[i * 3 + 1].strip()
+                    content = parts[i * 3 + 2]
+                    if recipient == 'all':
+                        text_content.append(content)
+                    else:
+                        tool_calls.append(
+                            ToolCall(
+                                id=gen_callid(),
+                                function=FunctionCall(name=recipient, arguments=json.loads(content))))
+
+                assert parts[-1].strip() == '', f'Unexpected content after tool calls: {parts[-1]}'
+
+                content = '\n'.join(text_content).strip()
+                return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls if tool_calls else None)
 
         return (converter.format_grammar(), parse)
 
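To see what the functionary-style regex captures, here is a worked example; the sample completion string is an assumption based on the <|recipient|>/<|content|> markers in the pattern. The split comes back in triples of (pre-text, recipient, content), where recipient "all" is user-visible text and anything else is a tool name:

import re

_recipient_content_re = re.compile(r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)

s = ('all\n<|content|>Checking the weather.<|stop|>'
     '<|from|> assistant\n<|recipient|>get_weather\n<|content|>{"location": "Glasgow"}')
parts = _recipient_content_re.split(s)
for i in range((len(parts) - 1) // 3):
    print(repr(parts[i * 3 + 1]), '->', repr(parts[i * 3 + 2]))
# 'all' -> 'Checking the weather.'
# 'get_weather' -> '{"location": "Glasgow"}'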
@@ -265,8 +314,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
 
         @typechecked
         def parse(s: str) -> Optional[Message]:
-            if response_rule.endswith(suffix):
-                return Message(role="assistant", content=s[:-len(suffix)])
+            s = strip_suffix(s)
+            return Message(role="assistant", content=s)
 
         return (converter.format_grammar(), parse)
 
@@ -275,9 +324,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
 
     @typechecked
     def parse(s: str) -> Optional[Message]:
-        if s.endswith(suffix):
-            return Message(role="assistant", content=s[:-len(suffix)])
-        return None
+        s = strip_suffix(s)
+        return Message(role="assistant", content=s)
 
     return (None, parse)
 
examples/openai/server.py

@@ -3,22 +3,27 @@
 
 import json, sys, subprocess, atexit
 from pathlib import Path
+import time
 
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 
 from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
 from examples.openai.gguf_kvs import GGUFKeyValues, Keys
-from examples.openai.api import Message, ChatCompletionRequest
+from examples.openai.api import ChatCompletionResponse, Choice, Message, ChatCompletionRequest, Usage
 from examples.openai.prompting import ChatFormat, make_grammar, make_tools_prompt
 
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse
 import httpx
+import random
 from starlette.responses import StreamingResponse
 from typing import Annotated, Optional
 import typer
 from typeguard import typechecked
 
+def generate_id(prefix):
+    return f"{prefix}{random.randint(0, 1 << 32)}"
+
 def main(
     model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
     # model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
@@ -68,17 +73,13 @@ def main(
 
         (grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
 
-        if chat_format.strict_user_assistant_alternation:
-            print("TODO: merge system messages into user messages")
-            # new_messages = []
-
         # TODO: Test whether the template supports formatting tool_calls
 
         prompt = chat_format.render(messages, add_generation_prompt=True)
         print(json.dumps(dict(
             stream=chat_request.stream,
             prompt=prompt,
-            grammar=grammar,
+            # grammar=grammar,
         ), indent=2))
         async with httpx.AsyncClient() as client:
             response = await client.post(
@@ -87,7 +88,7 @@ def main(
                     prompt=prompt,
                     stream=chat_request.stream,
                     n_predict=300,
-                    grammar=grammar,
+                    # grammar=grammar,
                 ).model_dump(),
                 headers=headers,
                 timeout=None)
@@ -98,11 +99,35 @@ def main(
                 return StreamingResponse(generate_chunks(response), media_type="text/event-stream")
             else:
                 result = response.json()
+                if 'content' not in result:
+                    # print(json.dumps(result, indent=2))
+                    return JSONResponse(result)
+
                 print(json.dumps(result, indent=2))
+                # print(json.dumps(result.get('content'), indent=2))
                 message = parser(result["content"])
-                assert message is not None, f"Failed to parse response: {response.text}"
-                return JSONResponse(message.model_dump())
-                # return JSONResponse(response.json())
+                assert message is not None, f"Failed to parse response:\n{response.text}\n\n"
+                prompt_tokens=result['timings']['prompt_n']
+                completion_tokens=result['timings']['predicted_n']
+                return JSONResponse(ChatCompletionResponse(
+                    id=generate_id('chatcmpl-'),
+                    object="chat.completion",
+                    created=int(time.time()),
+                    model=chat_request.model,
+                    choices=[Choice(
+                        index=0,
+                        message=message,
+
+                        finish_reason="stop" if message.tool_calls is None else "tool_calls",
+                    )],
+                    usage=Usage(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens,
+                    ),
+                    system_fingerprint='...'
+                ).model_dump())
 
     async def generate_chunks(response):
         async for chunk in response.aiter_bytes():
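The usage numbers above are lifted from the llama.cpp server's timings block (prompt_n / predicted_n). A minimal sketch of the mapping with example values:

result = {"timings": {"prompt_n": 100, "predicted_n": 35}}  # example values

prompt_tokens = result['timings']['prompt_n']         # tokens consumed by the prompt
completion_tokens = result['timings']['predicted_n']  # tokens generated
print(prompt_tokens, completion_tokens, prompt_tokens + completion_tokens)  # 100 35 135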
(shell test script)

@@ -12,7 +12,9 @@ function cleanup() {
 trap cleanup EXIT
 
 echo "# Starting the server"
-python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
+python -m examples.openai --model ~/AI/Models/functionary-medium-v2.2.q4_0.gguf &
+# python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
 # python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf &
 SERVER_PID=$!
 
@@ -73,8 +75,8 @@ curl http://localhost:8080/v1/chat/completions \
       }
     }],
     "messages": [
-      {"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
-      {"role": "user", "content": "what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
+      {"role": "user", "content": "I live in the UK. what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
     ]
   }'
 
+# {"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
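The same request can be made from Python instead of curl; a sketch assuming the server above is listening on http://localhost:8080 (the tool schemas are elided here, see the curl invocation above):

import httpx

resp = httpx.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "gpt-3.5-turbo",  # placeholder; the server serves whatever GGUF it loaded
        "messages": [
            {"role": "user", "content": "I live in the UK. what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
        ],
    },
    timeout=None,
)
data = resp.json()
print(data["choices"][0]["finish_reason"])
print(data["choices"][0]["message"])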
examples/openai/ts_converter.py

@@ -1,5 +1,5 @@
 from typing import Any, List, Set, Tuple, Union
-from jsonargparse import CLI
+import json
 
 class SchemaToTypeScriptConverter:
     # TODO: comments for arguments!
@@ -14,11 +14,14 @@ class SchemaToTypeScriptConverter:
     # // where to get weather.
     # location: string,
     # }) => any;
-    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
+    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], additional_properties: Union[bool, Any]):
-        return "{" + ', '.join(
+        return "{" + ', '.join([
             f'{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
             for prop_name, prop_schema in properties
-        ) + "}"
+        ] + (
+            [f"[key: string]: {self.visit(additional_properties)}"]
+            if additional_properties is not None else []
+        )) + "}"
 
     def visit(self, schema: dict):
         def print_constant(v):
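A sketch of what the updated _build_object_rule emits, reimplemented standalone with visit() stubbed so it runs outside the class; the property names and schemas are illustrative:

from typing import Any, List, Set, Tuple, Union

def build_object_rule(properties: List[Tuple[str, Any]], required: Set[str],
                      additional_properties: Union[dict, None]) -> str:
    def visit(schema):  # stub for SchemaToTypeScriptConverter.visit
        return schema.get("type", "any")
    return "{" + ', '.join([
        f'{prop_name}{"" if prop_name in required else "?"}: {visit(prop_schema)}'
        for prop_name, prop_schema in properties
    ] + (
        [f"[key: string]: {visit(additional_properties)}"]
        if additional_properties is not None else []
    )) + "}"

print(build_object_rule(
    [("location", {"type": "string"}), ("format", {"type": "string"})],
    required={"location"}, additional_properties=None))
# -> {location: string, format?: string}
print(build_object_rule(
    [("location", {"type": "string"})],
    required={"location"}, additional_properties={"type": "number"}))
# -> {location: string, [key: string]: number}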