From aa9605c7514531d6c2bbfef43a8e4bf801c925c7 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Wed, 27 Mar 2024 01:50:26 +0000
Subject: [PATCH] server.py: kinda api-compliant output, disabled grammar

---
 examples/openai/api.py          | 29 ++++++++++++-
 examples/openai/prompting.py    | 76 +++++++++++++++++++++++++++------
 examples/openai/server.py       | 45 ++++++++++++++-----
 examples/openai/test.sh         |  8 ++--
 examples/openai/ts_converter.py | 11 +++--
 5 files changed, 136 insertions(+), 33 deletions(-)

diff --git a/examples/openai/api.py b/examples/openai/api.py
index c44c6bfd1..0d7ddc111 100644
--- a/examples/openai/api.py
+++ b/examples/openai/api.py
@@ -1,10 +1,15 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Literal, Optional, Union
 from pydantic import BaseModel, Json

-class ToolCall(BaseModel):
+class FunctionCall(BaseModel):
     name: str
     arguments: Dict[str, Any]

+class ToolCall(BaseModel):
+    id: Optional[str] = None
+    type: Literal["function"] = "function"
+    function: FunctionCall
+
 class Message(BaseModel):
     role: str
     content: Optional[str]
@@ -30,3 +35,23 @@ class ChatCompletionRequest(BaseModel):
     response_format: Optional[ResponseFormat] = None
     temperature: float = 1.0
     stream: bool = False
+
+class Choice(BaseModel):
+    index: int
+    message: Message
+    logprobs: Optional[Json] = None
+    finish_reason: Union[Literal["stop"], Literal["tool_calls"]]
+
+class Usage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: Literal["chat.completion"]
+    created: int
+    model: str
+    choices: list[Choice]
+    usage: Usage
+    system_fingerprint: str
\ No newline at end of file
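
The new models above mirror the shape of OpenAI's chat-completion responses. Here is a minimal sketch of how they compose, assuming only the fields visible in this diff plus the tool_calls field on Message that prompting.py relies on; the id, model name, and token counts are illustrative placeholders:

    import time
    from examples.openai.api import (
        ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, Usage)

    response = ChatCompletionResponse(
        id="chatcmpl-123",  # placeholder; the server generates a random one
        object="chat.completion",
        created=int(time.time()),
        model="functionary-medium-v2.2",
        choices=[Choice(
            index=0,
            message=Message(
                role="assistant",
                content=None,
                tool_calls=[ToolCall(
                    id="call_42",
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments={"location": "Glasgow"}))]),
            finish_reason="tool_calls")],
        usage=Usage(prompt_tokens=100, completion_tokens=35, total_tokens=135),
        system_fingerprint="...")
    print(response.model_dump_json(indent=2))

Note that FunctionCall.arguments is a parsed dict here, whereas OpenAI's wire format carries arguments as a JSON-encoded string; clients expecting the latter would need to re-serialize it.
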
diff --git a/examples/openai/prompting.py b/examples/openai/prompting.py
index 0d7e0de56..ea6572d7b 100644
--- a/examples/openai/prompting.py
+++ b/examples/openai/prompting.py
@@ -2,13 +2,14 @@ from enum import Enum
 import jinja2
 import json
 from pathlib import Path
-import sys
+import random
 import re
+import sys
 from typing import Optional, Tuple, Callable
 from typeguard import typechecked

 from examples.json_schema_to_grammar import SchemaConverter
-from examples.openai.api import Tool, Message
+from examples.openai.api import Tool, Message, FunctionCall, ToolCall
 from examples.openai.gguf_kvs import GGUFKeyValues, Keys
 from examples.openai.ts_converter import SchemaToTypeScriptConverter

@@ -42,7 +43,7 @@ class ChatFormat:
             (i, m) = system_message
             return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
         else:
-            return [Message(role="system", content=system_prompt)] + messages
+            return [system_prompt] + messages

     @staticmethod
     def from_gguf(metadata: GGUFKeyValues):
@@ -69,7 +70,7 @@ class ChatFormat:
                 i += 1
             # print(f'new_messages={json.dumps(new_messages, indent=2)}')
             messages = new_messages
-        print(f'messages={messages}')
+        # print(f'messages={messages}')

         return self.template.render(
             messages=messages,
@@ -175,10 +176,15 @@ def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
         ToolsPromptStyle.TOOLS_HERMES_2_PRO,
     )

-_tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)
-
+_tool_call_re = re.compile(
+    '<tool_call>(.*?)</tool_call>', re.DOTALL)
+_recipient_content_re = re.compile(r'(?:(?:<\|(?:stop|from)\|>)+ *assistant\n<\|recipient\|>|^) *([^ <|>\n]+) *\n<\|content\|>(.*?)(?:$|<\|stop\|>\s*$|(?=(?:<\|(?:stop|from)\|>)+ *assistant\n))', re.DOTALL)
+
+def gen_callid():
+    return f'call_{random.randint(0, 1000000)}'
+
 @typechecked
-def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[Message]]]:
+def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[list[Message]]]]:

     converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)

@@ -191,6 +197,13 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
     assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
     [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)

+    def strip_suffix(s: str) -> str:
+        if s.endswith(suffix):
+            return s[:-len(suffix)]
+        else:
+            print(f"Expected suffix ({suffix}) not found: {s}")
+            return s
+
     if tools:
         if _outputs_tool_call_tags(chat_format.tool_style):
             tool_rules = [
@@ -221,6 +234,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op

             @typechecked
             def parse(s: str) -> Optional[Message]:
+                s = strip_suffix(s)
+
                 # ls = s.lstrip()
                 parts = _tool_call_re.split(s)
                 if len(parts) == 1:
@@ -232,10 +247,21 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op
                         if i % 2 == 0:
                             content.append(part)
                         else:
-                            tool_calls.append(json.loads(part))
+                            tool_calls.append(
+                                ToolCall(
+                                    id=gen_callid(),
+                                    function=FunctionCall(**json.loads(part))))
                     content = ''.join(content).strip()
                     return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls)

+                # if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
+                #     if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
+                #         tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
+                #         return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
+                #     return None
+                # else:
+                #     return Message(role="assistant", content=s)

             return (converter.format_grammar(), parse)
@@ -256,7 +282,30 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op

             @typechecked
             def parse(s: str) -> Optional[Message]:
-                raise NotImplementedError(f'TODO: parse tool_style {chat_format.tool_style}: {s}')
+                s = strip_suffix(s)
+
+                parts = _recipient_content_re.split(s)
+                if len(parts) == 1:
+                    return Message(role="assistant", content=s)
+                else:
+                    text_content = []
+                    tool_calls: list[ToolCall] = []
+                    for i in range((len(parts) - 1) // 3):
+                        assert parts[i * 3].strip() == '', f'Unexpected content before tool call: {parts[i * 3]}'
+                        recipient = parts[i * 3 + 1].strip()
+                        content = parts[i * 3 + 2]
+                        if recipient == 'all':
+                            text_content.append(content)
+                        else:
+                            tool_calls.append(
+                                ToolCall(
+                                    id=gen_callid(),
+                                    function=FunctionCall(name=recipient, arguments=json.loads(content))))
+
+                    assert parts[-1].strip() == '', f'Unexpected content after tool calls: {parts[-1]}'
+
+                    content = '\n'.join(text_content).strip()
+                    return Message(role="assistant", content=None if content == '' else content, tool_calls=tool_calls if tool_calls else None)

             return (converter.format_grammar(), parse)

@@ -265,8 +314,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op

         @typechecked
         def parse(s: str) -> Optional[Message]:
-            if response_rule.endswith(suffix):
-                return Message(role="assistant", content=s[:-len(suffix)])
+            s = strip_suffix(s)
+            return Message(role="assistant", content=s)

         return (converter.format_grammar(), parse)

@@ -275,9 +324,8 @@ def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Op

         @typechecked
         def parse(s: str) -> Optional[Message]:
-            if s.endswith(suffix):
-                return Message(role="assistant", content=s[:-len(suffix)])
-            return None
+            s = strip_suffix(s)
+            return Message(role="assistant", content=s)

         return (None, parse)
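
The two parsers above target different tool-call syntaxes: _tool_call_re extracts <tool_call>...</tool_call> blocks (the Hermes-2-Pro style), while _recipient_content_re splits functionary-style <|recipient|>/<|content|> sections. A standalone sketch of the first path, with the regex duplicated so the snippet runs on its own (the sample completion text is illustrative):

    import json
    import re

    _tool_call_re = re.compile('<tool_call>(.*?)</tool_call>', re.DOTALL)

    raw = ('Checking the weather.<tool_call>{"name": "get_current_weather", '
           '"arguments": {"location": "Glasgow"}}</tool_call>')

    parts = _tool_call_re.split(raw)
    # re.split with a capture group alternates text and captured payloads:
    # even indices are plain content, odd indices are tool-call JSON.
    text = ''.join(parts[0::2]).strip()
    calls = [json.loads(p) for p in parts[1::2]]
    print(text)              # Checking the weather.
    print(calls[0]['name'])  # get_current_weather

This is the same even/odd split the parse() above performs before wrapping each payload in a ToolCall with a freshly generated call id.
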
diff --git a/examples/openai/server.py b/examples/openai/server.py
index ca1e3eab5..8635da9e5 100644
--- a/examples/openai/server.py
+++ b/examples/openai/server.py
@@ -3,22 +3,27 @@ import json, sys, subprocess, atexit
 from pathlib import Path
+import time

 sys.path.insert(0, str(Path(__file__).parent.parent.parent))

 from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
 from examples.openai.gguf_kvs import GGUFKeyValues, Keys
-from examples.openai.api import Message, ChatCompletionRequest
+from examples.openai.api import ChatCompletionResponse, Choice, Message, ChatCompletionRequest, Usage
 from examples.openai.prompting import ChatFormat, make_grammar, make_tools_prompt

 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse
 import httpx
+import random
 from starlette.responses import StreamingResponse
 from typing import Annotated, Optional
 import typer
 from typeguard import typechecked

+def generate_id(prefix):
+    return f"{prefix}{random.randint(0, 1 << 32)}"
+
 def main(
     model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
     # model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
@@ -68,17 +73,13 @@ def main(

             (grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)

-            if chat_format.strict_user_assistant_alternation:
-                print("TODO: merge system messages into user messages")
-                # new_messages = []
-
             # TODO: Test whether the template supports formatting tool_calls

             prompt = chat_format.render(messages, add_generation_prompt=True)
             print(json.dumps(dict(
                 stream=chat_request.stream,
                 prompt=prompt,
-                grammar=grammar,
+                # grammar=grammar,
             ), indent=2))
             async with httpx.AsyncClient() as client:
                 response = await client.post(
@@ -87,7 +88,7 @@ def main(
                         prompt=prompt,
                         stream=chat_request.stream,
                         n_predict=300,
-                        grammar=grammar,
+                        # grammar=grammar,
                     ).model_dump(),
                     headers=headers,
                     timeout=None)
@@ -98,11 +99,35 @@ def main(
                 return StreamingResponse(generate_chunks(response), media_type="text/event-stream")
             else:
                 result = response.json()
+                if 'content' not in result:
+                    # print(json.dumps(result, indent=2))
+                    return JSONResponse(result)
+
+                print(json.dumps(result, indent=2))
                 # print(json.dumps(result.get('content'), indent=2))
                 message = parser(result["content"])
-                assert message is not None, f"Failed to parse response: {response.text}"
-                return JSONResponse(message.model_dump())
-                # return JSONResponse(response.json())
+                assert message is not None, f"Failed to parse response:\n{response.text}\n\n"
+
+                prompt_tokens=result['timings']['prompt_n']
+                completion_tokens=result['timings']['predicted_n']
+                return JSONResponse(ChatCompletionResponse(
+                    id=generate_id('chatcmpl-'),
+                    object="chat.completion",
+                    created=int(time.time()),
+                    model=chat_request.model,
+                    choices=[Choice(
+                        index=0,
+                        message=message,
+
+                        finish_reason="stop" if message.tool_calls is None else "tool_calls",
+                    )],
+                    usage=Usage(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens,
+                    ),
+                    system_fingerprint='...'
+                ).model_dump())

     async def generate_chunks(response):
         async for chunk in response.aiter_bytes():
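
With the server from test.sh running on localhost:8080, the endpoint can be exercised with any OpenAI-style client. A rough httpx sketch (field names follow the models in api.py; the tool schema and messages are illustrative, and per the code above the "model" field is only echoed back in the response, since routing uses whatever --model the server loaded):

    import httpx

    response = httpx.post(
        "http://localhost:8080/v1/chat/completions",
        json={
            "model": "gpt-3.5-turbo",
            "tools": [{
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "description": "Get the current weather for a location",
                    "parameters": {
                        "type": "object",
                        "properties": {"location": {"type": "string"}},
                        "required": ["location"],
                    },
                },
            }],
            "messages": [{"role": "user", "content": "Weather in Glasgow?"}],
        },
        timeout=None)
    choice = response.json()["choices"][0]
    if choice["finish_reason"] == "tool_calls":
        for call in choice["message"]["tool_calls"]:
            print(call["function"]["name"], call["function"]["arguments"])
    else:
        print(choice["message"]["content"])
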
diff --git a/examples/openai/test.sh b/examples/openai/test.sh
index 397682247..7dcc93e45 100755
--- a/examples/openai/test.sh
+++ b/examples/openai/test.sh
@@ -12,7 +12,9 @@ function cleanup() {
 trap cleanup EXIT

 echo "# Starting the server"
-python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
+
+python -m examples.openai --model ~/AI/Models/functionary-medium-v2.2.q4_0.gguf &
+# python -m examples.openai --model ~/AI/Models/mixtral-8x7b-instruct-v0.1.Q8_0.gguf &
 # python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf &
 SERVER_PID=$!

@@ -73,8 +75,8 @@ curl http://localhost:8080/v1/chat/completions \
       }
     }],
     "messages": [
-      {"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
-      {"role": "user", "content": "what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
+      {"role": "user", "content": "I live in the UK. what is the weather going to be like in San Francisco and Glasgow over the next 4 days."}
     ]
   }'
+# {"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
diff --git a/examples/openai/ts_converter.py b/examples/openai/ts_converter.py
index d018118cb..c0d99d0a4 100644
--- a/examples/openai/ts_converter.py
+++ b/examples/openai/ts_converter.py
@@ -1,5 +1,5 @@
 from typing import Any, List, Set, Tuple, Union
-from jsonargparse import CLI
+import json

 class SchemaToTypeScriptConverter:
     # TODO: comments for arguments!
@@ -14,11 +14,14 @@ class SchemaToTypeScriptConverter:
     #     // where to get weather.
     #     location: string,
     # }) => any;
-    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
-        return "{" + ', '.join(
+    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], additional_properties: Union[bool, Any]):
+        return "{" + ', '.join([
             f'{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
             for prop_name, prop_schema in properties
-        ) + "}"
+        ] + (
+            [f"[key: string]: {self.visit(additional_properties)}"]
+            if additional_properties is not None else []
+        )) + "}"

     def visit(self, schema: dict):
         def print_constant(v):
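
For context, the reworked _build_object_rule now emits a TypeScript index signature when additionalProperties carries a schema. A standalone illustration, with visit() stubbed to a toy type mapper since the real converter handles far more of JSON Schema:

    from typing import Any, List, Set, Tuple, Union

    def visit(schema: dict) -> str:
        # Toy stand-in for SchemaToTypeScriptConverter.visit
        return {"string": "string", "number": "number",
                "boolean": "boolean"}.get(schema.get("type"), "any")

    def build_object_rule(properties: List[Tuple[str, Any]], required: Set[str],
                          additional_properties: Union[bool, Any]) -> str:
        return "{" + ', '.join([
            f'{name}{"" if name in required else "?"}: {visit(schema)}'
            for name, schema in properties
        ] + (
            [f"[key: string]: {visit(additional_properties)}"]
            if additional_properties is not None else []
        )) + "}"

    print(build_object_rule(
        [("location", {"type": "string"}), ("num_days", {"type": "number"})],
        required={"location"},
        additional_properties={"type": "string"}))
    # -> {location: string, num_days?: number, [key: string]: string}

One caveat in the patched code: a literal additionalProperties of False (a bool) still passes the is-not-None check, so visit() would be handed a bool; callers presumably pass either a schema or None.
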