diff --git a/examples/openai/README.md b/examples/openai/README.md
new file mode 100644
index 000000000..47c9c67cc
--- /dev/null
+++ b/examples/openai/README.md
@@ -0,0 +1,53 @@
+# examples.openai: OpenAI API-compatible server
+
+A simple Python server that sits on top of the C++ [server](../server) and offers improved OpenAI API compatibility.
+
+## Usage
+
+```bash
+python -m examples.openai -m some-model.gguf
+```
+
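+The server listens on `localhost:8080` by default and exposes an OpenAI-style `/v1/chat/completions` endpoint. A minimal sketch of a request (the `model` field is required by the request schema, but its value is currently ignored):
+
+```bash
+curl http://localhost:8080/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "gpt-3.5-turbo",
+    "messages": [{"role": "user", "content": "Say hello."}]
+  }'
+```
+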
+## Features
+
+The new examples/openai/server.py:
+
+- Uses the llama.cpp C++ server as a backend (spawns it, or connects to an existing instance)
+
+- Renders prompts with the actual Jinja2 chat templates read from the models' GGUF metadata
+
+- Supports grammar-constrained output for both JSON response format and tool calls
+
+- Tool calling “works” w/ all models, even non-specialized ones like Mixtral 8x7B (see the tool-call example below)
+
+ - Optimised support for Functionary & Nous Hermes, easy to extend to other tool-calling fine-tunes
+
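+As a sketch of the tool-calling flow (same endpoint as above; `get_weather` is just a hypothetical tool used for illustration):
+
+```bash
+curl http://localhost:8080/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "gpt-3.5-turbo",
+    "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
+    "tools": [{
+      "type": "function",
+      "function": {
+        "name": "get_weather",
+        "description": "Get the current weather for a location",
+        "parameters": {
+          "type": "object",
+          "properties": {"location": {"type": "string"}},
+          "required": ["location"]
+        }
+      }
+    }]
+  }'
+```
+
+Depending on the detected chat format, the tool definitions are injected into the system prompt (as JSON for default / Hermes-style models, as a TypeScript-style `namespace functions` declaration for Functionary), and the output is constrained by a grammar derived from each tool's JSON schema.
+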
+## TODO
+
+- Embedding endpoint w/ distinct server subprocess
+
+- Automatic/manual session caching
+
+ - Spawns the main C++ CLI under the hood
+
+ - Support precaching long prompts from CLI
+
+ - Instant incremental inference in long threads
+
+- Improve examples/agent:
+
+ - Interactive agent CLI that auto-discovers tools from OpenAPI endpoints
+
+ - Script that wraps any Python source as a container-sandboxed OpenAPI endpoint (allowing running ~unsafe code w/ tools)
+
+ - Basic memory / RAG / python interpreter tools
+
+- Follow-ups
+
+ - Remove OAI support from server
+
+ - Remove non-Python json schema to grammar converters
+
+ - Reach out to frameworks to advertise new option.
diff --git a/examples/openai/__main__.py b/examples/openai/__main__.py
new file mode 100644
index 000000000..5204826b2
--- /dev/null
+++ b/examples/openai/__main__.py
@@ -0,0 +1,8 @@
+
+from jsonargparse import CLI
+
+from examples.openai.server import main
+
+if __name__ == "__main__":
+ CLI(main)
+
diff --git a/examples/openai/api.py b/examples/openai/api.py
new file mode 100644
index 000000000..b883ecec4
--- /dev/null
+++ b/examples/openai/api.py
@@ -0,0 +1,27 @@
+from typing import Any, Optional
+from pydantic import BaseModel
+
+class Message(BaseModel):
+ role: str
+ content: str
+
+class ToolFunction(BaseModel):
+ name: str
+ description: str
+ parameters: Any
+
+class Tool(BaseModel):
+ type: str
+ function: ToolFunction
+
+class ResponseFormat(BaseModel):
+ type: str
+ json_schema: Optional[Any] = None
+
+class ChatCompletionRequest(BaseModel):
+ model: str
+ tools: Optional[list[Tool]] = None
+ messages: list[Message]
+ response_format: Optional[ResponseFormat] = None
+ temperature: float = 1.0
+ stream: bool = False
diff --git a/examples/openai/chat_format.py b/examples/openai/chat_format.py
new file mode 100644
index 000000000..bb7d0c94c
--- /dev/null
+++ b/examples/openai/chat_format.py
@@ -0,0 +1,59 @@
+from enum import StrEnum
+import jinja2
+
+from examples.openai.gguf_kvs import GGUFKeyValues, Keys
+
+def raise_exception(msg: str):
+ raise Exception(msg)
+
+class ToolStyle(StrEnum):
+    # https://cookbook.openai.com/examples/how_to_call_functions_with_chat_models
+    DEFAULT = "Default"
+    # https://github.com/MeetKai/functionary
+    # TODO: look at https://github.com/ggerganov/llama.cpp/pull/5695
+    # https://github.com/MeetKai/functionary/blob/main/functionary/prompt_template/prompt_template_v2.py
+    FUNCTIONARY_V2 = "Functionary V2"
+    # https://github.com/NousResearch/Hermes-Function-Calling
+    NOUS_RESEARCH_HERMES = "Nous-Research-Hermes-Function-Calling"
+
+class ChatFormat: #(BaseModel):
+ def __init__(self, template: str, eos_token: str, bos_token: str):
+ env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
+ self.template = env.from_string(template)
+ self.eos_token = eos_token
+ self.bos_token = bos_token
+
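+        # Heuristic: some templates (e.g. Mistral/Mixtral) raise unless messages strictly
+        # alternate user/assistant roles, so system & tool messages need to be merged/converted.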
+ self.strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template
+
+ if "<|recipient|>' + tool_call['function']['name']" in template:
+ self.tool_style = ToolStyle.FUNCTIONARY_V2
+ else:
+ self.tool_style = ToolStyle.DEFAULT
+
+
+ def __str__(self):
+ return f"ChatFormat(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})"
+
+
+    @staticmethod
+    def from_gguf(metadata: GGUFKeyValues):
+        tokens = metadata[Keys.Tokenizer.LIST]
+        return ChatFormat(
+            template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
+            bos_token = tokens[metadata[Keys.Tokenizer.BOS_ID]],
+            eos_token = tokens[metadata[Keys.Tokenizer.EOS_ID]])
+
+ def render(self, messages: list[dict], add_generation_prompt: bool, omit_bos: bool = False):
+ return self.template.render(
+ messages=messages,
+ eos_token=self.eos_token,
+ bos_token='' if omit_bos else self.bos_token,
+ raise_exception=raise_exception,
+ add_generation_prompt=add_generation_prompt,
+ )
diff --git a/examples/openai/gguf_kvs.py b/examples/openai/gguf_kvs.py
new file mode 100644
index 000000000..2eba427b3
--- /dev/null
+++ b/examples/openai/gguf_kvs.py
@@ -0,0 +1,20 @@
+from pathlib import Path
+import sys
+
+sys.path.insert(0, str(Path(__file__).parent.parent.parent / "gguf-py"))
+
+from gguf.gguf_reader import GGUFReader
+from gguf.constants import Keys
+
+class GGUFKeyValues:
+    def __init__(self, model: Path):
+        reader = GGUFReader(model.as_posix())
+        self.fields = reader.fields
+
+    def __getitem__(self, key: str):
+        if '{arch}' in key:
+            key = key.replace('{arch}', self[Keys.General.ARCHITECTURE])
+        return self.fields[key].read()
+
+    def __contains__(self, key: str):
+        return key in self.fields
+
+    def keys(self):
+        return self.fields.keys()
diff --git a/examples/openai/llama_cpp_server_api.py b/examples/openai/llama_cpp_server_api.py
new file mode 100644
index 000000000..936900728
--- /dev/null
+++ b/examples/openai/llama_cpp_server_api.py
@@ -0,0 +1,28 @@
+from typing import Optional
+from pydantic import BaseModel, Json
+
+class LlamaCppServerCompletionRequest(BaseModel):
+ prompt: str
+ stream: Optional[bool] = None
+ cache_prompt: Optional[bool] = None
+ n_predict: Optional[int] = None
+ top_k: Optional[int] = None
+ top_p: Optional[float] = None
+ min_p: Optional[float] = None
+ tfs_z: Optional[float] = None
+ typical_p: Optional[float] = None
+ temperature: Optional[float] = None
+ dynatemp_range: Optional[float] = None
+ dynatemp_exponent: Optional[float] = None
+ repeat_last_n: Optional[int] = None
+ repeat_penalty: Optional[float] = None
+ frequency_penalty: Optional[float] = None
+ presence_penalty: Optional[float] = None
+    mirostat: Optional[int] = None  # 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0
+ mirostat_tau: Optional[float] = None
+ mirostat_eta: Optional[float] = None
+ penalize_nl: Optional[bool] = None
+ n_keep: Optional[int] = None
+ seed: Optional[int] = None
+ grammar: Optional[str] = None
+ json_schema: Optional[Json] = None
\ No newline at end of file
diff --git a/examples/openai/requirements.txt b/examples/openai/requirements.txt
new file mode 100644
index 000000000..219fda417
--- /dev/null
+++ b/examples/openai/requirements.txt
@@ -0,0 +1,7 @@
+fastapi[all]
+gguf
+jinja2
+jsonargparse
+pydantic
+sse-starlette
+uvicorn[all]
\ No newline at end of file
diff --git a/examples/openai/server.py b/examples/openai/server.py
new file mode 100644
index 000000000..db075bddb
--- /dev/null
+++ b/examples/openai/server.py
@@ -0,0 +1,215 @@
+import json, sys, subprocess, atexit
+from pathlib import Path
+
+# sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
+from examples.json_schema_to_grammar import SchemaConverter
+
+from typing import Optional
+import httpx
+from fastapi import Depends, FastAPI, Request, Response
+from starlette.responses import StreamingResponse
+from fastapi.responses import JSONResponse
+from jsonargparse import CLI
+
+from examples.openai.ts_converter import SchemaToTypeScriptConverter
+from examples.openai.gguf_kvs import GGUFKeyValues, Keys
+from examples.openai.api import Message, Tool, ToolFunction, ResponseFormat, ChatCompletionRequest
+from examples.openai.chat_format import ChatFormat, ToolStyle
+
+def _add_system_prompt(messages: list['Message'], system_prompt: str):
+ # TODO: add to last system message, or create a new one just before the last user message
+ system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
+ if system_message is not None:
+ (i, m) = system_message
+ messages[i].content = m.content + '\n' + system_prompt
+ else:
+ messages.insert(0, Message(role="system", content=system_prompt))
+ return messages
+
+def main(
+ model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
+ host: str = "localhost",
+ port: int = 8080,
+ main_server_endpoint: Optional[str] = None,
+ main_server_host: str = "localhost",
+ main_server_port: Optional[int] = 8081,
+):
+ import uvicorn
+
+ metadata = GGUFKeyValues(model)
+ context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
+ chat_format = ChatFormat.from_gguf(metadata)
+ print(chat_format)
+
+ if not main_server_endpoint:
+ server_process = subprocess.Popen([
+ "./server", "-m", model,
+ "--host", main_server_host, "--port", f'{main_server_port}',
+ '-ctk', 'q4_0', '-ctv', 'f16',
+ "-c", f"8192",
+ # "-c", f"{context_length}",
+ ])
+ atexit.register(server_process.kill)
+ main_server_endpoint = f"http://{main_server_host}:{main_server_port}"
+
+ app = FastAPI()
+
+ @app.post("/v1/chat/completions")
+ async def chat_completions(request: Request, chat_request: ChatCompletionRequest):
+        headers = {
+            "Content-Type": "application/json",
+        }
+        # Only forward an Authorization header if the client sent one
+        # (httpx rejects None header values).
+        if (authorization := request.headers.get("Authorization")) is not None:
+            headers["Authorization"] = authorization
+
+ if chat_request.response_format is not None:
+ assert chat_request.response_format.type == "json_object", f"Unsupported response format: {chat_request.response_format.type}"
+ response_schema = chat_request.response_format.json_schema or {}
+ else:
+ response_schema = None
+
+ messages = chat_request.messages
+        parser = None
+        grammar = None
+
+ converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
+
+ response_rule = converter.visit(response_schema, "response") if response_schema else None
+
+
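+        # Render the template twice (once with no messages, once with a planted assistant
+        # message) to recover the literal prefix & suffix that the template wraps around
+        # assistant output; the grammars below are anchored on those literals.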
+ delimiter = '<%$[SAMPLE]$%>'
+ empty_prompt = chat_format.render([], add_generation_prompt=True)
+ planted_prompt = chat_format.render([{"role": "assistant", "content": delimiter}], add_generation_prompt=False)
+ assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
+ [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
+
+ if chat_request.tools:
+ if chat_format.tool_style in (ToolStyle.DEFAULT, ToolStyle.NOUS_RESEARCH_HERMES):
+                messages = _add_system_prompt(messages, '\n'.join([
+                    'Here are the tools available:',
+                    '<tools>',
+                    *(tool.model_dump_json() for tool in chat_request.tools),
+                    '</tools>',
+                ]))
+
+ tool_rules = [
+ converter.visit(
+ dict(
+ type="object",
+ properties=dict(
+ name=dict(const=tool.function.name),
+ arguments=tool.function.parameters,
+ ),
+ required=['name', 'arguments']
+ ),
+ f'{tool.function.name}-tool-call'
+ )
+ for tool in chat_request.tools
+ ]
+
+                # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
+                # OR a tool-call message respecting the schema of any of the tools.
+                # Hermes-style tool calls are wrapped in <tool_call>...</tool_call> tags.
+                converter._add_rule(
+                    "root",
+                    converter._format_literal(prefix) + " (" +
+                        (response_rule or converter.not_literal("<tool_call>")) + " | " +
+                        converter._format_literal("<tool_call>") + " (" +
+                        ' | '.join(tool_rules) +
+                        ") " + converter._format_literal("</tool_call>") +
+                    ") " + converter._format_literal(suffix))
+                grammar = converter.format_grammar()
+
+                def parse(s: str):
+                    if s.startswith('<tool_call>'):
+                        if s.endswith('</tool_call>' + suffix):
+                            s = s[len('<tool_call>'):-len('</tool_call>' + suffix)]
+                            return {"role": "assistant", "tool_calls": [json.loads(s)]}
+                        return None
+                    else:
+                        return {"role": "assistant", "content": s}
+
+                parser = parse
+
+            elif chat_format.tool_style == ToolStyle.FUNCTIONARY_V2:
+
+                ts_converter = SchemaToTypeScriptConverter()
+
+                messages = _add_system_prompt(messages, '\n'.join([
+                    '// Supported function definitions that should be called when necessary.',
+                    'namespace functions {',
+                    *[
+                        '// ' + tool.function.description.replace('\n', '\n// ') + '\n' +
+                        'type ' + tool.function.name + ' = (_: ' + ts_converter.visit(tool.function.parameters) + ") => any;\n"
+                        for tool in chat_request.tools
+                    ],
+                    '} // namespace functions',
+                ]))
+
+                # Only allowing a single tool call at a time for now.
+                # Note that if there were more, they'd be separated by a '<|from|>assistant' literal
+                converter._add_rule(
+                    "root",
+                    converter._format_literal(prefix) + " (" +
+                        (response_rule or converter.not_literal("<|recipient|>")) + " | " +
+                        (' | '.join(
+                            converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
+                            converter.visit(tool.function.parameters, tool.function.name + '-args')
+                            for tool in chat_request.tools
+                        )) +
+                    ") " + converter._format_literal(suffix))
+                grammar = converter.format_grammar()
+            else:
+                raise NotImplementedError(f'Unsupported tool_style: {chat_format.tool_style}')
+
+        elif response_schema:
+            converter._add_rule('root', response_rule)
+            grammar = converter.format_grammar()
+
+            def parse(s):
+                if s.endswith(suffix):
+                    s = s[:-len(suffix)]
+                    return {"role": "assistant", "content": s}
+                return None
+
+            parser = parse
+
+        if chat_format.strict_user_assistant_alternation:
+            print("TODO: merge system messages into user messages")
+            # new_messages = []
+
+        # TODO: Test whether the template supports formatting tool_calls
+
+        prompt = chat_format.render(messages, add_generation_prompt=True)
+        # print(prompt)
+        # print(grammar)
+        print(json.dumps(dict(
+            prompt=prompt,
+            stream=chat_request.stream,
+            grammar=grammar,
+        ), indent=2))
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"{main_server_endpoint}/completions",
+                json=LlamaCppServerCompletionRequest(
+                    prompt=prompt,
+                    stream=chat_request.stream,
+                    n_predict=100,
+                    grammar=grammar,
+                ).model_dump(),
+                headers=headers,
+                timeout=None)
+
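+        # TODO: client.post() buffers the entire response and the AsyncClient is closed
+        # once this handler returns, so the streaming branch below doesn't truly stream;
+        # it would need client.stream(...) kept alive (e.g. via a background task).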
+        return StreamingResponse(generate_chunks(response), media_type="text/event-stream") if chat_request.stream \
+            else JSONResponse(response.json())
+
+ async def generate_chunks(response):
+ async for chunk in response.aiter_bytes():
+ yield chunk
+
+ uvicorn.run(app, host=host, port=port)
+
+if __name__ == "__main__":
+ CLI(main)
+
diff --git a/examples/openai/ts_converter.py b/examples/openai/ts_converter.py
new file mode 100644
index 000000000..d018118cb
--- /dev/null
+++ b/examples/openai/ts_converter.py
@@ -0,0 +1,85 @@
+import json
+from typing import Any, List, Set, Tuple, Union
+
+class SchemaToTypeScriptConverter:
+ # TODO: comments for arguments!
+ # // Get the price of a particular car model
+ # type get_car_price = (_: {
+ # // The name of the car model.
+ # car_name: string,
+ # }) => any;
+
+ # // get the weather of a location
+ # type get_weather = (_: {
+ # // where to get weather.
+ # location: string,
+ # }) => any;
+    def __init__(self):
+        # Schemas referenced via $ref are looked up here; callers are expected to
+        # populate this mapping before visiting (no automatic resolution is done).
+        self._refs: dict[str, Any] = {}
+
+    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], additional_properties: Union[bool, Any]):
+ return "{" + ', '.join(
+ f'{prop_name}{"" if prop_name in required else "?"}: {self.visit(prop_schema)}'
+ for prop_name, prop_schema in properties
+ ) + "}"
+
+ def visit(self, schema: dict):
+ def print_constant(v):
+ return json.dumps(v)
+
+ schema_type = schema.get('type')
+ schema_format = schema.get('format')
+
+ if 'oneOf' in schema or 'anyOf' in schema:
+ return '|'.join(self.visit(s) for s in schema.get('oneOf') or schema.get('anyOf'))
+
+ elif isinstance(schema_type, list):
+ return '|'.join(self.visit({'type': t}) for t in schema_type)
+
+ elif 'const' in schema:
+ return print_constant(schema['const'])
+
+ elif 'enum' in schema:
+ return '|'.join((print_constant(v) for v in schema['enum']))
+
+ elif schema_type in (None, 'object') and \
+ ('properties' in schema or \
+ ('additionalProperties' in schema and schema['additionalProperties'] is not True)):
+ required = set(schema.get('required', []))
+ properties = list(schema.get('properties', {}).items())
+ return self._build_object_rule(properties, required, schema.get('additionalProperties'))
+
+ elif schema_type in (None, 'object') and 'allOf' in schema:
+ required = set()
+ properties = []
+ def add_component(comp_schema, is_required):
+ if (ref := comp_schema.get('$ref')) is not None:
+ comp_schema = self._refs[ref]
+
+ if 'properties' in comp_schema:
+ for prop_name, prop_schema in comp_schema['properties'].items():
+ properties.append((prop_name, prop_schema))
+ if is_required:
+ required.add(prop_name)
+
+ for t in schema['allOf']:
+ if 'anyOf' in t:
+ for tt in t['anyOf']:
+ add_component(tt, is_required=False)
+ else:
+ add_component(t, is_required=True)
+
+ return self._build_object_rule(properties, required, additional_properties=[])
+
+ elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
+ items = schema.get('items') or schema['prefixItems']
+ if isinstance(items, list):
+ return '[' + ', '.join(self.visit(item) for item in items) + '][]'
+ else:
+ return self.visit(items) + '[]'
+
+ elif schema_type in (None, 'string') and schema_format == 'date-time':
+ return 'Date'
+
+ elif (schema_type == 'object') or (len(schema) == 0):
+ return 'any'
+
+ else:
+ return 'number' if schema_type == 'integer' else schema_type