diff --git a/examples/openai/api.py b/examples/openai/api.py
index c44c6bfd1..0d7ddc111 100644
--- a/examples/openai/api.py
+++ b/examples/openai/api.py
@@ -1,10 +1,15 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Literal, Optional
from pydantic import BaseModel, Json
-class ToolCall(BaseModel):
+class FunctionCall(BaseModel):
name: str
arguments: Dict[str, Any]
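+# Mirrors OpenAI's chat tool_call objects: an id plus a typed function payload.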
+class ToolCall(BaseModel):
+ id: Optional[str] = None
+ type: Literal["function"] = "function"
+ function: FunctionCall
+
class Message(BaseModel):
role: str
content: Optional[str]
@@ -30,3 +35,23 @@ class ChatCompletionRequest(BaseModel):
response_format: Optional[ResponseFormat] = None
temperature: float = 1.0
stream: bool = False
+
+class Choice(BaseModel):
+ index: int
+ message: Message
+ logprobs: Optional[Json] = None
+ finish_reason: Literal["stop", "tool_calls"]
+
+class Usage(BaseModel):
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+
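+# Response envelope matching OpenAI's /v1/chat/completions schema.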
+class ChatCompletionResponse(BaseModel):
+ id: str
+ object: Literal["chat.completion"]
+ created: int
+ model: str
+ choices: list[Choice]
+ usage: Usage
+ system_fingerprint: str
\ No newline at end of file
diff --git a/examples/openai/prompting.py b/examples/openai/prompting.py
index 0d7e0de56..ea6572d7b 100644
--- a/examples/openai/prompting.py
+++ b/examples/openai/prompting.py
@@ -2,13 +2,14 @@ from enum import Enum
import jinja2
import json
from pathlib import Path
-import sys
+import random
import re
+import sys
from typing import Optional, Tuple, Callable
from typeguard import typechecked
from examples.json_schema_to_grammar import SchemaConverter
-from examples.openai.api import Tool, Message
+from examples.openai.api import Tool, Message, FunctionCall, ToolCall
from examples.openai.gguf_kvs import GGUFKeyValues, Keys
from examples.openai.ts_converter import SchemaToTypeScriptConverter
@@ -42,7 +43,7 @@ class ChatFormat:
(i, m) = system_message
return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
else:
- return [Message(role="system", content=system_prompt)] + messages
+ return [system_prompt] + messages
@staticmethod
def from_gguf(metadata: GGUFKeyValues):
@@ -69,7 +70,7 @@ class ChatFormat:
i += 1
# print(f'new_messages={json.dumps(new_messages, indent=2)}')
messages = new_messages
- print(f'messages={messages}')
+ # print(f'messages={messages}')
return self.template.render(
messages=messages,
@@ -175,10 +176,15 @@ def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
ToolsPromptStyle.TOOLS_HERMES_2_PRO,
)
-_tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)
-
+_tool_call_re = re.compile(
+ '<tool_call>(.*?)</tool_call>', re.DOTALL)
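+# Matches functionary-v2 style sections of the form <|recipient|>NAME\n<|content|>BODY; recipient "all" denotes plain text.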
+_recipient_content_re = re.compile(r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)
+
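+# Synthesize an OpenAI-style tool call id (the model itself does not emit ids).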
+def gen_callid():
+ return f'call_{random.randint(0, 1000000)}'
+
@typechecked
-def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[Message]]]:
+def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[list[Message]]]]:
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
@@ -191,6 +197,13 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
[prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
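+ # Strip the template's trailing suffix (e.g. the EOS marker) from the raw completion, if present.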
+ def strip_suffix(s: str) -> str:
+ if s.endswith(suffix):
+ return s[:-len(suffix)]
+ else:
+ print(f"Expected suffix ({suffix}) not found: {s}")
+ return s
+
if tools:
if _outputs_tool_call_tags(chat_format.tool_style):
tool_rules = [
@@ -221,6 +234,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
@typechecked
def parse(s: str) -> Optional[Message]:
+ s = strip_suffix(s)
+
# ls = s.lstrip()
parts = _tool_call_re.split(s)
if len(parts) == 1:
@@ -232,10 +247,21 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
if i % 2 == 0:
content.append(part)
else:
- tool_calls.append(json.loads(part))
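+ # Wrap each <tool_call> JSON payload in an OpenAI-style ToolCall with a synthetic id.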
+ tool_calls.append(
+ ToolCall(
+ id=gen_callid(),
+ function=FunctionCall(**json.loads(part))))
content = ''.join(content).strip()
return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls)
+
+ # if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
+ # if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
+ # tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
+ # return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
+ # return None
+ # else:
+ # return Message(role="assistant", content=s)
return (converter.format_grammar(), parse)
@@ -256,7 +282,30 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
@typechecked
def parse(s: str) -> Optional[Message]:
- raise NotImplementedError(f'TODO: parse tool_style {chat_format.tool_style}: {s}')
+ s = strip_suffix(s)
+
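+ # Each regex match yields (prefix, recipient, content); recipient "all" is user-facing text, anything else names a tool.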
+ parts = _recipient_content_re.split(s)
+ if len(parts) == 1:
+ return Message(role="assistant", content=s)
+ else:
+ text_content = []
+ tool_calls: list[ToolCall] = []
+ for i in range((len(parts) - 1) // 3):
+ assert parts[i * 3].strip() == '', f'Unexpected content before tool call: {parts[i * 3]}'
+ recipient = parts[i * 3 + 1].strip()
+ content = parts[i * 3 + 2]
+ if recipient == 'all':
+ text_content.append(content)
+ else:
+ tool_calls.append(
+ ToolCall(
+ id=gen_callid(),
+ function=FunctionCall(name=recipient, arguments=json.loads(content))))
+
+ assert parts[-1].strip() == '', f'Unexpected content after tool calls: {parts[-1]}'
+
+ content = '\n'.join(text_content).strip()
+ return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls if tool_calls else None)
return (converter.format_grammar(), parse)
@@ -265,8 +314,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
@typechecked
def parse(s: str) -> Optional[Message]:
- if response_rule.endswith(suffix):
- return Message(role="assistant", content=s[:-len(suffix)])
+ s = strip_suffix(s)
+ return Message(role="assistant", content=s)
return (converter.format_grammar(), parse)
@@ -275,9 +324,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
@typechecked
def parse(s: str) -> Optional[Message]:
- if s.endswith(suffix):
- return Message(role="assistant", content=s[:-len(suffix)])
- return None
+ s = strip_suffix(s)
+ return Message(role="assistant", content=s)
return (None, parse)
diff --git a/examples/openai/server.py b/examples/openai/server.py
index ca1e3eab5..8635da9e5 100644
--- a/examples/openai/server.py
+++ b/examples/openai/server.py
@@ -3,22 +3,27 @@
import json, sys, subprocess, atexit
from pathlib import Path
+import time
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
from examples.openai.gguf_kvs import GGUFKeyValues, Keys
-from examples.openai.api import Message, ChatCompletionRequest
+from examples.openai.api import ChatCompletionResponse, Choice, Message, ChatCompletionRequest, Usage
from examples.openai.prompting import ChatFormat, make_grammar, make_tools_prompt
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import httpx
+import random
from starlette.responses import StreamingResponse
from typing import Annotated, Optional
import typer
from typeguard import typechecked
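+# Random id in the OpenAI style, e.g. "chatcmpl-1234567".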
+def generate_id(prefix):
+ return f"{prefix}{random.randint(0, 1 << 32)}"
+
def main(
model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
# model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
@@ -68,17 +73,13 @@ def main(
(grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
- if chat_format.strict_user_assistant_alternation:
- print("TODO: merge system messages into user messages")
- # new_messages = []
-
# TODO: Test whether the template supports formatting tool_calls
prompt = chat_format.render(messages, add_generation_prompt=True)
print(json.dumps(dict(
stream=chat_request.stream,
prompt=prompt,
- grammar=grammar,
+ # grammar=grammar,
), indent=2))
async with httpx.AsyncClient() as client:
response = await client.post(
@@ -87,7 +88,7 @@ def main(
prompt=prompt,
stream=chat_request.stream,
n_predict=300,
- grammar=grammar,
+ # grammar=grammar,
).model_dump(),
headers=headers,
timeout=None)
@@ -98,11 +99,35 @@ def main(
return StreamingResponse(generate_chunks(response), media_type="text/event-stream")
else:
result = response.json()
+ if 'content' not in result:
+ # print(json.dumps(result, indent=2))
+ return JSONResponse(result)
+
print(json.dumps(result, indent=2))
+ # print(json.dumps(result.get('content'), indent=2))
message = parser(result["content"])
- assert message is not None, f"Failed to parse response: {response.text}"
- return JSONResponse(message.model_dump())
- # return JSONResponse(response.json())
+ assert message is not None, f"Failed to parse response:\n{response.text}\n\n"
+
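+ # llama.cpp's completion endpoint reports token counts in its "timings" block.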
+ prompt_tokens = result['timings']['prompt_n']
+ completion_tokens = result['timings']['predicted_n']
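+ # Wrap the parsed message in an OpenAI-compatible response envelope.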
+ return JSONResponse(ChatCompletionResponse(
+ id=generate_id('chatcmpl-'),
+ object="chat.completion",
+ created=int(time.time()),
+ model=chat_request.model,
+ choices=[Choice(
+ index=0,
+ message=message,
+ finish_reason="stop" if message.tool_calls is None else "tool_calls",
+ )],
+ usage=Usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ ),
+ system_fingerprint='...'
+ ).model_dump())
async def generate_chunks(response):
async for chunk in response.aiter_bytes():
diff --git a/examples/openai/test.sh b/examples/openai/test.sh
index 397682247..7dcc93e45 100755
--- a/examples/openai/test.sh
+++ b/examples/openai/test.sh
@@ -12,7 +12,9 @@ function cleanup() {
trap cleanup EXIT
echo "# Starting the server"
-python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
+
+python -m examples.openai --model ~/AI/Models/functionary-medium-v2.2.q4_0.gguf &
+# python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
# python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf &
SERVER_PID=$!
@@ -73,8 +75,8 @@ curl http://localhost:8080/v1/chat/completions \
}
}],
"messages": [
- {"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
- {"role": "user", "content": "what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
+ {"role": "user", "content": "I live in the UK. what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
]
}'
+# {"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
diff --git a/examples/openai/ts_converter.py b/examples/openai/ts_converter.py
index d018118cb..c0d99d0a4 100644
--- a/examples/openai/ts_converter.py
+++ b/examples/openai/ts_converter.py
@@ -1,5 +1,5 @@
from typing import Any, List, Set, Tuple, Union
-from jsonargparse import CLI
+import json
class SchemaToTypeScriptConverter:
# TODO: comments for arguments!
@@ -14,11 +14,14 @@ class SchemaToTypeScriptConverter:
# // where to get weather.
# location: string,
# }) => any;
- def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
- return "{" + ', '.join(
+ def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], additional_properties: Union[bool, Any]):
+ return "{" + ', '.join([
f'{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
for prop_name, prop_schema in properties
- ) + "}"
+ ] + (
+ [f"[key: string]: {self.visit(additional_properties)}"]
+ if additional_properties is not None and additional_properties is not False else []
+ )) + "}"
def visit(self, schema: dict):
def print_constant(v):