From 63d13245e1668b01533765e00958c19b27df29fc Mon Sep 17 00:00:00 2001
From: ochafik
Date: Mon, 25 Mar 2024 23:57:25 +0000
Subject: [PATCH] server.py: hacky code

---
 examples/openai/README.md               |  53 ++++++
 examples/openai/__main__.py             |   8 +
 examples/openai/api.py                  |  27 +++
 examples/openai/chat_format.py          |  59 +++++++
 examples/openai/gguf_kvs.py             |  20 +++
 examples/openai/llama_cpp_server_api.py |  28 +++
 examples/openai/requirements.txt        |   7 +
 examples/openai/server.py               | 215 ++++++++++++++++++++++++
 examples/openai/ts_converter.py         |  85 ++++++++++
 9 files changed, 502 insertions(+)
 create mode 100644 examples/openai/README.md
 create mode 100644 examples/openai/__main__.py
 create mode 100644 examples/openai/api.py
 create mode 100644 examples/openai/chat_format.py
 create mode 100644 examples/openai/gguf_kvs.py
 create mode 100644 examples/openai/llama_cpp_server_api.py
 create mode 100644 examples/openai/requirements.txt
 create mode 100644 examples/openai/server.py
 create mode 100644 examples/openai/ts_converter.py

diff --git a/examples/openai/README.md b/examples/openai/README.md
new file mode 100644
index 000000000..47c9c67cc
--- /dev/null
+++ b/examples/openai/README.md
@@ -0,0 +1,53 @@
+# examples.openai: OpenAI API-compatible server
+
+A simple Python server that sits on top of the C++ [server](../server) and offers improved OpenAI API compatibility.
+
+## Usage
+
+```bash
+python -m examples.openai -m some-model.gguf
+```
+
+(See the example request at the end of this README.)
+
+## Features
+
+The new examples/openai/server.py:
+
+- Uses the llama.cpp C++ server as a backend (spawns it, or connects to an existing instance)
+
+- Uses the actual Jinja2 chat templates read from the models
+
+- Supports grammar-constrained output, both for the JSON response format and for tool calls
+
+- Tool calling “works” w/ all models (even non-specialized ones like Mixtral 8x7B)
+
+  - Optimised support for Functionary & Nous Hermes; easy to extend to other tool-calling fine-tunes
+
+## TODO
+
+- Embeddings endpoint w/ a distinct server subprocess
+
+- Automatic/manual session caching
+
+  - Spawns the main C++ CLI under the hood
+
+  - Support precaching long prompts from the CLI
+
+  - Instant incremental inference in long threads
+
+- Improve examples/agent:
+
+  - Interactive agent CLI that auto-discovers tools from OpenAPI endpoints
+
+  - Script that wraps any Python source as a container-sandboxed OpenAPI endpoint (allowing ~unsafe code to be run w/ tools)
+
+  - Basic memory / RAG / Python interpreter tools
+
+- Follow-ups
+
+  - Remove OAI support from the C++ server
+
+  - Remove the non-Python JSON schema to grammar converters
+
+  - Reach out to frameworks to advertise the new option.
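+
+## Example request
+
+A minimal sketch of a chat completion request against the defaults above (`http://localhost:8080`),
+shown with `httpx` purely for illustration (any HTTP or OpenAI-compatible client should work):
+
+```python
+import httpx
+
+# The `model` field is currently ignored by server.py: the GGUF passed on its command line is used.
+response = httpx.post(
+    "http://localhost:8080/v1/chat/completions",
+    json={
+        "model": "gpt-3.5-turbo",
+        "messages": [
+            {"role": "user", "content": "What is the capital of France?"},
+        ],
+    },
+    timeout=None,
+)
+print(response.json())
+```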
diff --git a/examples/openai/__main__.py b/examples/openai/__main__.py
new file mode 100644
index 000000000..5204826b2
--- /dev/null
+++ b/examples/openai/__main__.py
@@ -0,0 +1,8 @@
+
+from jsonargparse import CLI
+
+from examples.openai.server import main
+
+if __name__ == "__main__":
+    CLI(main)
+
diff --git a/examples/openai/api.py b/examples/openai/api.py
new file mode 100644
index 000000000..b883ecec4
--- /dev/null
+++ b/examples/openai/api.py
@@ -0,0 +1,27 @@
+from typing import Any, Optional
+from pydantic import BaseModel, Json
+
+class Message(BaseModel):
+    role: str
+    content: str
+
+class ToolFunction(BaseModel):
+    name: str
+    description: str
+    parameters: Any
+
+class Tool(BaseModel):
+    type: str
+    function: ToolFunction
+
+class ResponseFormat(BaseModel):
+    type: str
+    json_schema: Optional[Any] = None
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    tools: Optional[list[Tool]] = None
+    messages: list[Message]
+    response_format: Optional[ResponseFormat] = None
+    temperature: float = 1.0
+    stream: bool = False
diff --git a/examples/openai/chat_format.py b/examples/openai/chat_format.py
new file mode 100644
index 000000000..bb7d0c94c
--- /dev/null
+++ b/examples/openai/chat_format.py
@@ -0,0 +1,59 @@
+from enum import StrEnum
+import jinja2
+
+from examples.openai.gguf_kvs import GGUFKeyValues, Keys
+
+def raise_exception(msg: str):
+    raise Exception(msg)
+
+class ToolStyle(StrEnum):
+    # https://cookbook.openai.com/examples/how_to_call_functions_with_chat_models
+    DEFAULT = "Default"
+    # https://github.com/MeetKai/functionary
+    # TODO: look at https://github.com/ggerganov/llama.cpp/pull/5695
+    # https://github.com/MeetKai/functionary/blob/main/functionary/prompt_template/prompt_template_v2.py
+    FUNCTIONARY_V2 = "Functionary V2"
+    # https://github.com/NousResearch/Hermes-Function-Calling
+    NOUS_RESEARCH_HERMES = "Nous-Research-Hermes-Function-Calling"
+
+class ChatFormat: #(BaseModel):
+    def __init__(self, template: str, eos_token: str, bos_token: str):
+        env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
+        self.template = env.from_string(template)
+        self.eos_token = eos_token
+        self.bos_token = bos_token
+
+        self.strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template
+
+        if "<|recipient|>' + tool_call['function']['name']" in template:
+            self.tool_style = ToolStyle.FUNCTIONARY_V2
+        else:
+            self.tool_style = ToolStyle.DEFAULT
+
+    def __str__(self):
+        return f"ChatFormat(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})"
+
+    @staticmethod
+    def from_gguf(metadata: GGUFKeyValues):
+        # Look up the BOS/EOS token *strings*: the GGUF metadata only stores their vocab ids.
+        tokens = metadata[Keys.Tokenizer.LIST]
+        return ChatFormat(
+            template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
+            bos_token = tokens[metadata[Keys.Tokenizer.BOS_ID]],
+            eos_token = tokens[metadata[Keys.Tokenizer.EOS_ID]])
+    # @staticmethod
+    # def from_gguf(model: Path):
+    #     reader = GGUFReader(model.as_posix())
+    #     return ChatFormat(
+    #         template = reader.fields[Keys.Tokenizer.CHAT_TEMPLATE].read(),
+    #         bos_token = reader.fields[Keys.Tokenizer.BOS_ID].read(),
+    #         eos_token = reader.fields[Keys.Tokenizer.EOS_ID].read())
+
+    def render(self, messages: list[dict], add_generation_prompt: bool, omit_bos: bool = False):
+        return self.template.render(
+            messages=messages,
+            eos_token=self.eos_token,
+            bos_token='' if omit_bos else self.bos_token,
+            raise_exception=raise_exception,
+            add_generation_prompt=add_generation_prompt,
+        )
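+
+# Illustrative usage (mirrors what server.py does; "some-model.gguf" is a placeholder path):
+#
+#   from pathlib import Path
+#   metadata = GGUFKeyValues(Path("some-model.gguf"))
+#   chat_format = ChatFormat.from_gguf(metadata)
+#   prompt = chat_format.render(
+#       [{"role": "user", "content": "Hello!"}],
+#       add_generation_prompt=True)
+#   # `prompt` is the exact text to send to the llama.cpp server's /completions endpoint.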
diff --git a/examples/openai/gguf_kvs.py b/examples/openai/gguf_kvs.py
new file mode 100644
index 000000000..2eba427b3
--- /dev/null
+++ b/examples/openai/gguf_kvs.py
@@ -0,0 +1,20 @@
+from pathlib import Path
+import sys
+
+sys.path.insert(0, str(Path(__file__).parent.parent.parent / "gguf-py"))
+
+from gguf.gguf_reader import GGUFReader
+from gguf.constants import Keys
+
+class GGUFKeyValues:
+    def __init__(self, model: Path):
+        reader = GGUFReader(model.as_posix())
+        self.fields = reader.fields
+    def __getitem__(self, key: str):
+        if '{arch}' in key:
+            key = key.replace('{arch}', self[Keys.General.ARCHITECTURE])
+        return self.fields[key].read()
+    def __contains__(self, key: str):
+        return key in self.fields
+    def keys(self):
+        return self.fields.keys()
diff --git a/examples/openai/llama_cpp_server_api.py b/examples/openai/llama_cpp_server_api.py
new file mode 100644
index 000000000..936900728
--- /dev/null
+++ b/examples/openai/llama_cpp_server_api.py
@@ -0,0 +1,28 @@
+from typing import Optional
+from pydantic import BaseModel, Json
+
+class LlamaCppServerCompletionRequest(BaseModel):
+    prompt: str
+    stream: Optional[bool] = None
+    cache_prompt: Optional[bool] = None
+    n_predict: Optional[int] = None
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    min_p: Optional[float] = None
+    tfs_z: Optional[float] = None
+    typical_p: Optional[float] = None
+    temperature: Optional[float] = None
+    dynatemp_range: Optional[float] = None
+    dynatemp_exponent: Optional[float] = None
+    repeat_last_n: Optional[int] = None
+    repeat_penalty: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    mirostat: Optional[bool] = None
+    mirostat_tau: Optional[float] = None
+    mirostat_eta: Optional[float] = None
+    penalize_nl: Optional[bool] = None
+    n_keep: Optional[int] = None
+    seed: Optional[int] = None
+    grammar: Optional[str] = None
+    json_schema: Optional[Json] = None
\ No newline at end of file
diff --git a/examples/openai/requirements.txt b/examples/openai/requirements.txt
new file mode 100644
index 000000000..219fda417
--- /dev/null
+++ b/examples/openai/requirements.txt
@@ -0,0 +1,7 @@
+fastapi[all]
+gguf
+jinja2
+jsonargparse
+pydantic
+sse-starlette
+uvicorn[all]
\ No newline at end of file
diff --git a/examples/openai/server.py b/examples/openai/server.py
new file mode 100644
index 000000000..db075bddb
--- /dev/null
+++ b/examples/openai/server.py
@@ -0,0 +1,215 @@
+import json, sys, subprocess, atexit
+from pathlib import Path
+
+# sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
+from examples.json_schema_to_grammar import SchemaConverter
+
+from typing import Optional
+import httpx
+from fastapi import Depends, FastAPI, Request, Response
+from starlette.responses import StreamingResponse
+from fastapi.responses import JSONResponse
+from jsonargparse import CLI
+
+from examples.openai.ts_converter import SchemaToTypeScriptConverter
+from examples.openai.gguf_kvs import GGUFKeyValues, Keys
+from examples.openai.api import Message, Tool, ToolFunction, ResponseFormat, ChatCompletionRequest
+from examples.openai.chat_format import ChatFormat, ToolStyle
+
+def _add_system_prompt(messages: list['Message'], system_prompt: str):
+    # TODO: add to the last system message, or create a new one just before the last user message
+    system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
+    if system_message is not None:
+        (i, m) = system_message
+        messages[i].content = m.content + '\n' + system_prompt
+    else:
+        messages.insert(0, Message(role="system", content=system_prompt))
+    return messages
+
+def main(
+    model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
+    host: str = "localhost",
+    port: int = 8080,
+    main_server_endpoint: Optional[str] = None,
+    main_server_host: str = "localhost",
+    main_server_port: Optional[int] = 8081,
+):
+    import uvicorn
+
+    metadata = GGUFKeyValues(model)
+    context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
+    chat_format = ChatFormat.from_gguf(metadata)
+    print(chat_format)
+
+    if not main_server_endpoint:
+        server_process = subprocess.Popen([
+            "./server", "-m", model,
+            "--host", main_server_host, "--port", f'{main_server_port}',
+            '-ctk', 'q4_0', '-ctv', 'f16',
+            "-c", f"8192",
+            # "-c", f"{context_length}",
+        ])
+        atexit.register(server_process.kill)
+        main_server_endpoint = f"http://{main_server_host}:{main_server_port}"
+
+    app = FastAPI()
+
+    @app.post("/v1/chat/completions")
+    async def chat_completions(request: Request, chat_request: ChatCompletionRequest):
+        headers = {
+            "Content-Type": "application/json",
+        }
+        # Only forward the Authorization header when the client actually sent one.
+        if (authorization := request.headers.get("Authorization")):
+            headers["Authorization"] = authorization
+
+        if chat_request.response_format is not None:
+            assert chat_request.response_format.type == "json_object", f"Unsupported response format: {chat_request.response_format.type}"
+            response_schema = chat_request.response_format.json_schema or {}
+        else:
+            response_schema = None
+
+        messages = chat_request.messages
+        parser = None
+        grammar = None
+
+        converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
+
+        response_rule = converter.visit(response_schema, "response") if response_schema else None
+
+        # Render the template once with a throw-away delimiter to discover which text the template
+        # puts before (prefix) and after (suffix) an assistant message.
+        delimiter = '<%$[SAMPLE]$%>'
+        empty_prompt = chat_format.render([], add_generation_prompt=True)
+        planted_prompt = chat_format.render([{"role": "assistant", "content": delimiter}], add_generation_prompt=False)
+        assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
+        [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
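+        # Illustrative values for a ChatML-style template (e.g. the default Hermes 2 Pro model):
+        #   empty_prompt   == "<|im_start|>assistant\n"
+        #   planted_prompt == "<|im_start|>assistant\n<%$[SAMPLE]$%><|im_end|>\n"
+        #   prefix         == ""
+        #   suffix         == "<|im_end|>\n"
+        # The parsers below use `suffix` to strip the end-of-turn marker from raw completions.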
+
+        if chat_request.tools:
+            if chat_format.tool_style in (ToolStyle.DEFAULT, ToolStyle.NOUS_RESEARCH_HERMES):
+                messages = _add_system_prompt(messages, '\n'.join([
+                    'Here are the tools available:',
+                    '<tools>',
+                    *(tool.model_dump_json() for tool in chat_request.tools),
+                    '</tools>',
+                ]))
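+                # Illustrative exchange with a hypothetical get_weather tool: the system prompt now
+                # ends with a block like
+                #   Here are the tools available:
+                #   <tools>
+                #   {"type": "function", "function": {"name": "get_weather", ...}}
+                #   </tools>
+                # and the grammar built below forces the completion to be either plain content or
+                #   <tool_call>{"name": "get_weather", "arguments": {"location": "Paris"}}</tool_call>
+                # which parse() (defined below) maps to {"role": "assistant", "tool_calls": [...]}.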
+
+                tool_rules = [
+                    converter.visit(
+                        dict(
+                            type="object",
+                            properties=dict(
+                                name=dict(const=tool.function.name),
+                                arguments=tool.function.parameters,
+                            ),
+                            required=['name', 'arguments']
+                        ),
+                        f'{tool.function.name}-tool-call'
+                    )
+                    for tool in chat_request.tools
+                ]
+
+                # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
+                # OR a tool-call message respecting the schema of any of the tools
+                converter._add_rule(
+                    "root",
+                    converter._format_literal(prefix) + " (" +
+                    (response_rule or converter.not_literal("<tool_call>")) + " | " +
+                    converter._format_literal("<tool_call>") + " (" +
+                    ' | '.join(tool_rules) +
+                    ") " + converter._format_literal("</tool_call>") +
+                    ") " + converter._format_literal(suffix))
+                grammar = converter.format_grammar()
+
+                def parse(s: str):
+                    if s.startswith('<tool_call>'):
+                        if s.endswith('</tool_call>' + suffix):
+                            s = s[len('<tool_call>'):-len('</tool_call>' + suffix)]
+                            return {"role": "assistant", "tool_calls": [json.loads(s)]}
+                        return None
+                    else:
+                        return {"role": "assistant", "content": s}
+
+                parser = parse
+
+            elif chat_format.tool_style == ToolStyle.FUNCTIONARY_V2:
+
+                ts_converter = SchemaToTypeScriptConverter()
+
+                messages = _add_system_prompt(messages, '\n'.join([
+                    '// Supported function definitions that should be called when necessary.',
+                    'namespace functions {',
+                    *[
+                        '// ' + tool.function.description.replace('\n', '\n// ') + '\n' +
+                        'type ' + tool.function.name + ' = (_: ' + ts_converter.visit(tool.function.parameters) + ") => any;\n"
+                        for tool in chat_request.tools
+                    ],
+                    '} // namespace functions',
+                ]))
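+                # Illustrative system prompt block for a hypothetical get_weather tool:
+                #   // Supported function definitions that should be called when necessary.
+                #   namespace functions {
+                #   // Get the current weather for a location.
+                #   type get_weather = (_: {location: string}) => any;
+                #   } // namespace functions
+                # The model is then expected to reply with turns like
+                #   <|recipient|>get_weather
+                #   <|content|>{"location": "Paris"}
+                # which the grammar built below enforces.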
+
+                # Only allowing a single tool call at a time for now.
+                # Note that if there were more, they'd be separated by a '<|from|>assistant' literal
+                converter._add_rule(
+                    "root",
+                    converter._format_literal(prefix) + " (" +
+                    (response_rule or converter.not_literal("<|recipient|>")) + " | " +
+                    (' | '.join(
+                        converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
+                        converter.visit(tool.function.parameters, tool.function.name + '-args')
+                        for tool in chat_request.tools
+                    )) +
+                    ") " + converter._format_literal(suffix))
+                grammar = converter.format_grammar()
+            else:
+                raise NotImplementedError(f'Unsupported tool_style: {chat_format.tool_style}')
+
+        elif response_schema:
+            converter._add_rule('root', response_rule)
+            grammar = converter.format_grammar()
+
+            def parse(s):
+                if s.endswith(suffix):
+                    s = s[:-len(suffix)]
+                    return {"role": "assistant", "content": s}
+                return None
+
+            parser = parse
+
+        if chat_format.strict_user_assistant_alternation:
+            print("TODO: merge system messages into user messages")
+            # new_messages = []
+
+        # TODO: Test whether the template supports formatting tool_calls
+
+        prompt = chat_format.render(messages, add_generation_prompt=True)
+        # print(prompt)
+        # print(grammar)
+        print(json.dumps(dict(
+            prompt=prompt,
+            stream=chat_request.stream,
+            grammar=grammar,
+        ), indent=2))
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"{main_server_endpoint}/completions",
+                json=LlamaCppServerCompletionRequest(
+                    prompt=prompt,
+                    stream=chat_request.stream,
+                    n_predict=100,
+                    grammar=grammar,
+                ).model_dump(),
+                headers=headers,
+                timeout=None)
+
+            return StreamingResponse(generate_chunks(response), media_type="text/event-stream") if chat_request.stream \
+                else JSONResponse(response.json())
+
+    async def generate_chunks(response):
+        async for chunk in response.aiter_bytes():
+            yield chunk
+
+    uvicorn.run(app, host=host, port=port)
+
+if __name__ == "__main__":
+    CLI(main)
+
diff --git a/examples/openai/ts_converter.py b/examples/openai/ts_converter.py
new file mode 100644
index 000000000..d018118cb
--- /dev/null
+++ b/examples/openai/ts_converter.py
@@ -0,0 +1,85 @@
+import json
+from typing import Any, List, Set, Tuple, Union
+from jsonargparse import CLI
+
+class SchemaToTypeScriptConverter:
+    # TODO: comments for arguments!
+    # // Get the price of a particular car model
+    # type get_car_price = (_: {
+    #   // The name of the car model.
+    #   car_name: string,
+    # }) => any;
+
+    # // get the weather of a location
+    # type get_weather = (_: {
+    #   // where to get weather.
+    #   location: string,
+    # }) => any;
+
+    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], additional_properties: Union[bool, Any] = None):
+        return "{" + ', '.join(
+            f'{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
+            for prop_name, prop_schema in properties
+        ) + "}"
+
+    def visit(self, schema: dict):
+        def print_constant(v):
+            return json.dumps(v)
+
+        schema_type = schema.get('type')
+        schema_format = schema.get('format')
+
+        if 'oneOf' in schema or 'anyOf' in schema:
+            return '|'.join(self.visit(s) for s in schema.get('oneOf') or schema.get('anyOf'))
+
+        elif isinstance(schema_type, list):
+            return '|'.join(self.visit({'type': t}) for t in schema_type)
+
+        elif 'const' in schema:
+            return print_constant(schema['const'])
+
+        elif 'enum' in schema:
+            return '|'.join((print_constant(v) for v in schema['enum']))
+
+        elif schema_type in (None, 'object') and \
+                ('properties' in schema or \
+                 ('additionalProperties' in schema and schema['additionalProperties'] is not True)):
+            required = set(schema.get('required', []))
+            properties = list(schema.get('properties', {}).items())
+            return self._build_object_rule(properties, required, schema.get('additionalProperties'))
+
+        elif schema_type in (None, 'object') and 'allOf' in schema:
+            required = set()
+            properties = []
+            def add_component(comp_schema, is_required):
+                if (ref := comp_schema.get('$ref')) is not None:
+                    comp_schema = self._refs[ref]
+
+                if 'properties' in comp_schema:
+                    for prop_name, prop_schema in comp_schema['properties'].items():
+                        properties.append((prop_name, prop_schema))
+                        if is_required:
+                            required.add(prop_name)
+
+            for t in schema['allOf']:
+                if 'anyOf' in t:
+                    for tt in t['anyOf']:
+                        add_component(tt, is_required=False)
+                else:
+                    add_component(t, is_required=True)
+
+            return self._build_object_rule(properties, required, additional_properties=[])
+
+        elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
+            items = schema.get('items') or schema['prefixItems']
+            if isinstance(items, list):
+                return '[' + ', '.join(self.visit(item) for item in items) + '][]'
+            else:
+                return self.visit(items) + '[]'
+
+        elif schema_type in (None, 'string') and schema_format == 'date-time':
+            return 'Date'
+
+        elif (schema_type == 'object') or (len(schema) == 0):
+            return 'any'
+
+        else:
+            return 'number' if schema_type == 'integer' else schema_type
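+
+# Illustrative example (not executed): converting a small tool-parameter schema,
+#
+#   SchemaToTypeScriptConverter().visit({
+#       "type": "object",
+#       "properties": {
+#           "location": {"type": "string"},
+#           "unit": {"enum": ["celsius", "fahrenheit"]},
+#       },
+#       "required": ["location"],
+#   })
+#
+# returns the TypeScript-ish signature used in Functionary-style prompts:
+#   {location: string, unit?: "celsius"|"fahrenheit"}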