server.py: default tools work!

This commit is contained in:
ochafik 2024-03-26 20:58:03 +00:00
parent ffc74360e2
commit d5d9993679
15 changed files with 449 additions and 223 deletions

3
.gitmodules vendored
View file

@ -1,6 +1,3 @@
[submodule "kompute"]
path = kompute
url = https://github.com/nomic-ai/kompute.git
[submodule "examples/agents/hermes_function_calling"]
path = examples/agents/hermes_function_calling
url = https://github.com/NousResearch/Hermes-Function-Calling

View file

@ -1,15 +0,0 @@
Edit `examples/agents/hermes_function_calling/utils.py`:
```py
log_folder = os.environ.get('LOG_FOLDER', os.path.join(script_dir, "inference_logs"))
```
Then run:
```bash
REQUIREMENTS_FILE=<( cat examples/agents/hermes_function_calling/requirements.txt | grep -vE "bitsandbytes|flash-attn" ) \
examples/agents/run_sandboxed_tools.sh \
examples/agents/hermes_function_calling/functions.py \
-e LOG_FOLDER=/data/inference_logs
```

@ -1 +0,0 @@
Subproject commit b4f757e27d87f4ab408f706f482c25a8e1508d59

View file

@ -1,3 +0,0 @@
jsonargparse
pydantic
typer[all]

View file

@ -1,15 +1,45 @@
# examples.openai: OpenAI API-compatible server
# examples.openai: OpenAI API-compatible server + agent / tools examples
A simple Python server that sits above the C++ [../server](examples/server) and offers improved OAI compatibility.
## Usage
Run a simple test:
```bash
python -m examples.openai -m some-model.gguf
# Spawns a Python server (which spawns a C++ Server) then hits it w/ a tool-calling request
examples/openai/test.sh
```
To simply run the Python server (+ C++ server under the hood):
```bash
python -m examples.openai
```
## Tools usage (WIP)
```bash
git clone https://github.com/NousResearch/Hermes-Function-Calling examples/openai/hermes_function_calling
```
Then edit `examples/agents/hermes_function_calling/utils.py`:
```py
log_folder = os.environ.get('LOG_FOLDER', os.path.join(script_dir, "inference_logs"))
```
Then run tools in a sandbox:
```bash
REQUIREMENTS_FILE=<( cat examples/agents/hermes_function_calling/requirements.txt | grep -vE "bitsandbytes|flash-attn" ) \
examples/agents/run_sandboxed_tools.sh \
examples/agents/hermes_function_calling/functions.py \
-e LOG_FOLDER=/data/inference_logs
```
TODO: reactor that reads OpenAPI definitions and does the tool calling
## Features
The new examples/openai/server.py:

View file

@ -1,9 +1,14 @@
from typing import Any, Optional
from typing import Any, Dict, Optional
from pydantic import BaseModel, Json
class ToolCall(BaseModel):
name: str
arguments: Dict[str, Any]
class Message(BaseModel):
role: str
content: str
content: Optional[str]
tool_calls: Optional[list[ToolCall]] = None
class ToolFunction(BaseModel):
name: str

View file

@ -1,59 +0,0 @@
from enum import StrEnum
import jinja2
from examples.openai.gguf_kvs import GGUFKeyValues, Keys
def raise_exception(msg: str):
raise Exception(msg)
class ToolStyle(StrEnum):
# https://cookbook.openai.com/examples/how_to_call_functions_with_chat_models
DEFAULT="Default",
# https://github.com/MeetKai/functionary
# TODO: look at https://github.com/ggerganov/llama.cpp/pull/5695
# https://github.com/MeetKai/functionary/blob/main/functionary/prompt_template/prompt_template_v2.py
FUNCTIONARY_V2="Functionary V2",
# https://github.com/NousResearch/Hermes-Function-Calling
NOUS_RESEARCH_HERMES="Nous-Research-Hermes-Function-Calling",
class ChatFormat: #(BaseModel):
def __init__(self, template: str, eos_token: str, bos_token: str):
env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
self.template = env.from_string(template)
self.eos_token = eos_token
self.bos_token = bos_token
self.strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template
if "<|recipient|>' + tool_call['function']['name']" in template:
self.tool_style = ToolStyle.FUNCTIONARY_V2
else:
self.tool_style = ToolStyle.DEFAULT
def __str__(self):
return f"ChatFormat(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})"
@staticmethod
def from_gguf(metadata: GGUFKeyValues):
return ChatFormat(
template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
bos_token = metadata[Keys.Tokenizer.BOS_ID],
eos_token = metadata[Keys.Tokenizer.EOS_ID])
# @staticmethod
# def from_gguf(model: Path):
# reader = GGUFReader(model.as_posix())
# return ChatFormat(
# template = reader.fields[Keys.Tokenizer.CHAT_TEMPLATE].read(),
# bos_token = reader.fields[Keys.Tokenizer.BOS_ID].read(),
# eos_token = reader.fields[Keys.Tokenizer.EOS_ID].read())
def render(self, messages: list[dict], add_generation_prompt: bool, omit_bos: bool = False):
return self.template.render(
messages=messages,
eos_token=self.eos_token,
bos_token='' if omit_bos else self.bos_token,
raise_exception=raise_exception,
add_generation_prompt=add_generation_prompt,
)

View file

@ -8,8 +8,6 @@ from anyio import Path
import fastapi, uvicorn
import typer
# from langchain_core.tools import BaseTool
def load_source_as_module(source):
i = 0
while (module_name := f'mod_{i}') in sys.modules:

View file

@ -0,0 +1,43 @@
<|im_start|>system
Role:
You are a function calling AI agent with self-recursion.
You can call only one function at a time and analyse data you get from function response.
You are provided with function signatures within <tools></tools> XML tags.
The current date is: March 25, 2024.
Objective:
You may use agentic frameworks for reasoning and planning to help with user query.
Please call a function and wait for function results to be provided to you in the next iteration.
Don't make assumptions about what values to plug into function arguments.
Once you have called a function, results will be fed back to you within <tool_response></tool_response> XML tags.
Don't make assumptions about tool results if <tool_response> XML tags are not present since function hasn't been executed yet.
Analyze the data once you get the results and call another function.
At each iteration please continue adding the your analysis to previous summary.
Your final response should directly answer the user query with an anlysis or summary of the results of function calls.
Tools:
Here are the available tools:
<tools>
{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"format":{"type":"string","enum":["celsius","fahrenheit"],"description":"The temperature unit to use. Infer this from the users location."}},"required":["location","format"]}}}
{"type":"function","function":{"name":"get_n_day_weather_forecast","description":"Get an N-day weather forecast","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"format":{"type":"string","enum":["celsius","fahrenheit"],"description":"The temperature unit to use. Infer this from the users location."},"num_days":{"type":"integer","description":"The number of days to forecast"}},"required":["location","format","num_days"]}}}
</tools>
If the provided function signatures doesn't have the function you must call, you may write executable python code in markdown syntax and call code_interpreter() function as follows:
<tool_call>
{"arguments": {"code_markdown": <python-code>, "name": "code_interpreter"}}
</tool_call>
Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
Instructions:
At the very first turn you don't have <tool_results> so you shouldn't not make up the results.
Please keep a running summary with analysis of previous function results and summaries from previous iterations.
Do not stop calling functions until the task has been accomplished or you've reached max iteration of 10.
Calling multiple functions at once can overload the system and increase cost so call one function at a time please.
If you plan to continue with analysis, always call another function.
For each function call return a valid json object (using doulbe quotes) with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"arguments": <args-dict>, "name": <function-name>}
</tool_call>
<|im_end|>
<|im_start|>user
what is the weather going to be like in San Francisco and Glasgow over the next 4 days (temperature in celsius for both)<|im_end|>
<|im_start|>assistant

View file

@ -0,0 +1,242 @@
from enum import Enum
import jinja2
import json
from pathlib import Path
import sys
from typing import Optional, Tuple, Callable
from typeguard import typechecked
from examples.json_schema_to_grammar import SchemaConverter
from examples.openai.api import Tool, Message
from examples.openai.gguf_kvs import GGUFKeyValues, Keys
from examples.openai.ts_converter import SchemaToTypeScriptConverter
@typechecked
def raise_exception(msg: str):
raise Exception(msg)
class ChatFormat:
def __init__(self, template: str, eos_token: str, bos_token: str):
env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
self.template = env.from_string(template)
self.eos_token = eos_token
self.bos_token = bos_token
self.strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template
if "<|recipient|>' + tool_call['function']['name']" in template:
self.tool_style = ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2
else:
self.tool_style = ToolsPromptStyle.TOOLS_LONG
def __str__(self):
return f"ChatFormat(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})"
@staticmethod
def from_gguf(metadata: GGUFKeyValues):
return ChatFormat(
template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
bos_token = metadata[Keys.Tokenizer.BOS_ID],
eos_token = metadata[Keys.Tokenizer.EOS_ID])
def render(self, messages: list[dict], add_generation_prompt: bool, omit_bos: bool = False):
return self.template.render(
messages=messages,
eos_token=self.eos_token,
bos_token='' if omit_bos else self.bos_token,
raise_exception=raise_exception,
add_generation_prompt=add_generation_prompt,
)
# While the API will be usable with a generic tools usage like OpenAI,
# (see https://cookbook.openai.com/examples/how_to_call_functions_with_chat_models),
# each model may need specific prompting (and/or constrained output,
# especially for models not fine-tuned for tool usage / function calling).
class ToolsPromptStyle(Enum):
# Short prompt w/ <tools>schemas</tools>
TOOLS_SHORT = 1
# Longer prompt w/ <tools>schemas</tools>
TOOLS_LONG = 2
# Large prompt for https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B
# Requires:
# - git clone https://github.com/NousResearch/Hermes-Function-Calling examples/openai/hermes_function_calling
# - Set large context length as their prompts are super long
TOOLS_HERMES_2_PRO = 3
# Short prompt w/ TypeScript definitions for https://github.com/MeetKai/functionary
# https://github.com/MeetKai/functionary/blob/main/functionary/prompt_template/prompt_template_v2.py
# Note: see this prior attempt to support Functionary: https://github.com/ggerganov/llama.cpp/pull/5695
TYPESCRIPT_FUNCTIONARY_V2 = 4
@typechecked
def make_tools_prompt(chat_format: ChatFormat, tools: list[Tool], indent=2) -> Message:
if chat_format.tool_style == ToolsPromptStyle.TOOLS_SHORT:
return Message(
role="system",
content='\n'.join([
'Here are the tools available:',
'<tools>',
*(json.dumps(tool.model_dump(), indent=indent) for tool in tools),
'</tools>',
])
)
elif chat_format.tool_style == ToolsPromptStyle.TOOLS_LONG:
return Message(
role="system",
content='\n'.join([
'''You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.''',
'''You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:''',
'''<tools>''',
*(json.dumps(tool.model_dump(), indent=indent) for tool in tools),
'''</tools>''',
'',
'''Use the following json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}''',
'',
'''For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:''',
'''<tool_call>''',
'''{"arguments": <args-dict>, "name": <function-name>}''',
'''</tool_call>''',
])
)
elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
ts_converter = SchemaToTypeScriptConverter()
return Message(
role="system",
content='\n'.join([
'// Supported function definitions that should be called when necessary.'
'namespace functions {',
*[
'// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
'type ' + tool.function.name + ' = (_: ' + ts_converter.visit(tool.function.parameters) + ") => any;\n"
for tool in tools
],
'} // namespace functions',
])
)
elif chat_format.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
# Hackily import https://github.com/NousResearch/Hermes-Function-Calling
path = str(Path(__file__).parent / "hermes_function_calling")
if path not in sys.path: sys.path.insert(0, path)
try:
from examples.openai.hermes_function_calling.prompter import PromptManager
except ImportError:
raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`")
prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in tools])
assert len(prompt) == 1 and prompt[0]["role"] == "system"
return Message(**prompt[0])
else:
raise ValueError(f"Unsupported tool call style: {chat_format.tool_style}")
@typechecked
def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool:
return style in (
ToolsPromptStyle.TOOLS_SHORT,
ToolsPromptStyle.TOOLS_LONG,
ToolsPromptStyle.TOOLS_HERMES_2_PRO,
)
@typechecked
def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[Message]]]:
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
response_rule = converter.visit(response_schema, "response") if response_schema else None
delimiter = '<%$[SAMPLE]$%>'
empty_prompt = chat_format.render([], add_generation_prompt=True).strip()
planted_prompt = chat_format.render([{"role": "assistant", "content": delimiter}], add_generation_prompt=False).strip()
assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
[prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
if tools:
if _outputs_tool_call_tags(chat_format.tool_style):
tool_rules = [
converter.visit(
dict(
type="object",
properties=dict(
name=dict(const=tool.function.name),
arguments=tool.function.parameters,
),
required=['name', 'arguments']
),
f'{tool.function.name}-tool-call'
)
for tool in tools
]
# Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
# OR a tool-call message respecting the schema of any of the tools
converter._add_rule(
"root",
converter._format_literal(prefix) + " (" +
(response_rule or converter.not_literal("<tool_call>")) + " | " +
converter._format_literal("<tool_call>") + " (" +
' | '.join(tool_rules) +
") " + converter._format_literal("</tool_call>") +
")") # + converter._format_literal(suffix))
@typechecked
def parse(s: str) -> Optional[Message]:
ls = s.lstrip()
if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
return None
else:
return Message(role="assistant", content=s)
return (converter.format_grammar(), parse)
elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
# Only allowing a single tool call at a time for now.
# Note that if there were more, they'd be separated by a '<|from|>assistant' literal
converter._add_rule(
"root",
converter._format_literal(prefix) + " (" +
(response_rule or converter.not_literal("<|recipient|>")) + " | " +
(' | '.join(
converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
converter.visit(tool.function.parameters, tool.function.name + '-args')
for tool in tools
)) +
") " +
")") # + converter._format_literal(suffix))
@typechecked
def parse(s: str) -> Optional[Message]:
raise NotImplementedError(f'TODO: parse tool_style {chat_format.tool_style}: {s}')
return (converter.format_grammar(), parse)
elif response_schema:
converter._add_rule("root", response_rule + ' ' + converter._format_literal(suffix))
@typechecked
def parse(s: str) -> Optional[Message]:
if response_rule.endswith(suffix):
return Message(role="assistant", content=s[:-len(suffix)])
return (converter.format_grammar(), parse)
else:
converter._add_rule("root", converter._format_literal(prefix) + ' ' + converter._format_literal(suffix))
@typechecked
def parse(s: str) -> Optional[Message]:
if s.endswith(suffix):
return Message(role="assistant", content=s[:-len(suffix)])
return None
return (None, parse)

View file

@ -1,7 +1,7 @@
fastapi[all]
gguf
jinja2
jsonargparse
pydantic
sse-starlette
uvicorn[all]
typer[all]

View file

@ -2,9 +2,9 @@
#
# Runs a Python script in a sandboxed environment and makes its functions available as a web service.
#
# git submodule add https://github.com/NousResearch/Hermes-Function-Calling examples/agents/hermes_function_calling
# python examples/agents/fastify.py examples/agents/hermes_function_calling/functions.py
# REQUIREMENTS_FILE=<( cat examples/agents/hermes_function_calling/requirements.txt | grep -vE "bitsandbytes|flash-attn" ) examples/agents/run_sandboxed_tools.sh examples/agents/hermes_function_calling/functions.py -e LOG_FOLDER=/data/inference_logs
# git submodule add https://github.com/NousResearch/Hermes-Function-Calling examples/openai/hermes_function_calling
# python examples/openai/fastify.py examples/openai/hermes_function_calling/functions.py
# REQUIREMENTS_FILE=<( cat examples/openai/hermes_function_calling/requirements.txt | grep -vE "bitsandbytes|flash-attn" ) examples/agents/run_sandboxed_tools.sh examples/agents/hermes_function_calling/functions.py -e LOG_FOLDER=/data/inference_logs
set -euo pipefail
script="$( realpath "$1" )"

View file

@ -1,35 +1,39 @@
# https://gist.github.com/ochafik/a3d4a5b9e52390544b205f37fb5a0df3
# pip install "fastapi[all]" "uvicorn[all]" sse-starlette jsonargparse jinja2 pydantic
import json, sys, subprocess, atexit
from pathlib import Path
# sys.path.insert(0, str(Path(__file__).parent.parent))
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
from examples.json_schema_to_grammar import SchemaConverter
from typing import Optional
import httpx
from fastapi import Depends, FastAPI, Request, Response
from starlette.responses import StreamingResponse
from fastapi.responses import JSONResponse
from jsonargparse import CLI
from examples.openai.ts_converter import SchemaToTypeScriptConverter
from examples.openai.gguf_kvs import GGUFKeyValues, Keys
from examples.openai.api import Message, Tool, ToolFunction, ResponseFormat, ChatCompletionRequest
from examples.openai.chat_format import ChatFormat, ToolStyle
from examples.openai.api import Message, ChatCompletionRequest
from examples.openai.prompting import ChatFormat, make_grammar, make_tools_prompt
def _add_system_prompt(messages: list['Message'], system_prompt: str):
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import httpx
from starlette.responses import StreamingResponse
from typing import Annotated, Optional
import typer
from typeguard import typechecked
@typechecked
def _add_system_prompt(messages: list[Message], system_prompt: Message) -> list[Message]:
assert system_prompt.role == "system"
# TODO: add to last system message, or create a new one just before the last user message
system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None)
if system_message is not None:
(i, m) = system_message
messages[i].content = m.content + '\n' + system_prompt
return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:]
else:
messages.insert(0, Message(role="system", content=system_prompt))
return messages
return [Message(role="system", content=system_prompt)] + messages
def main(
model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
# model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
# model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
host: str = "localhost",
port: int = 8080,
main_server_endpoint: Optional[str] = None,
@ -48,7 +52,7 @@ def main(
"./server", "-m", model,
"--host", main_server_host, "--port", f'{main_server_port}',
'-ctk', 'q4_0', '-ctv', 'f16',
"-c", f"8192",
"-c", f"{2*8192}",
# "-c", f"{context_length}",
])
atexit.register(server_process.kill)
@ -70,110 +74,10 @@ def main(
response_schema = None
messages = chat_request.messages
parser=None
grammar=None
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
response_rule = converter.visit(response_schema, "response") if response_schema else None
delimiter = '<%$[SAMPLE]$%>'
empty_prompt = chat_format.render([], add_generation_prompt=True)
planted_prompt = chat_format.render([{"role": "assistant", "content": delimiter}], add_generation_prompt=False)
assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}"
[prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter)
if chat_request.tools:
if chat_format.tool_style in (ToolStyle.DEFAULT, ToolStyle.NOUS_RESEARCH_HERMES):
messages = _add_system_prompt(messages, '\n'.join([
'Here are the tools available:',
'<tools>',
*(tool.model_dump_json() for tool in chat_request.tools),
'</tools>',
]))
messages = _add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools))
tool_rules = [
converter.visit(
dict(
type="object",
properties=dict(
name=dict(const=tool.function.name),
arguments=tool.function.parameters,
),
required=['name', 'arguments']
),
f'{tool.function.name}-tool-call'
)
for tool in chat_request.tools
]
# Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
# OR a tool-call message respecting the schema of any of the tools
converter._add_rule(
"root",
converter._format_literal(prefix) + " (" +
(response_rule or converter.not_literal("<tool_call>")) + " | " +
converter._format_literal("<tool_call>") + " (" +
' | '.join(tool_rules) +
") " + converter._format_literal("</tool_call>") +
") " + converter._format_literal(suffix))
grammar = converter.format_grammar()
def parse(s: str):
if '<tool_call>'.startswith(s):
if s.startswith('<tool_call>') and s.endswith('</tool_call>' + suffix):
s = s[len('<tool_call>'):-len('</tool_call>' + suffix)]
return {"role": "assistant", "tool_calls": [json.loads(s)]}
return None
else:
return {"role": "assistant", "content": s}
parser = parse
elif chat_format.tool_style == ToolStyle.FUNCTIONARY_V2:
ts_converter = SchemaToTypeScriptConverter()
messages = _add_system_prompt(messages, '\n'.join([
'// Supported function definitions that should be called when necessary.'
'namespace functions {',
*[
'// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
'type ' + tool.function.name + ' = (_: ' + ts_converter.visit(tool.function.parameters) + ") => any;\n"
for tool in chat_request.tools
],
'} // namespace functions',
]))
# Only allowing a single tool call at a time for now.
# Note that if there were more, they'd be separated by a '<|from|>assistant' literal
converter._add_rule(
"root",
converter._format_literal(prefix) + " (" +
(response_rule or converter.not_literal("<|recipient|>")) + " | " +
(' | '.join(
converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
converter.visit(tool.function.parameters, tool.function.name + '-args')
for tool in chat_request.tools
)) +
") " +
") " + converter._format_literal(suffix))
grammar = converter.format_grammar()
else:
raise NotImplementedError(f'Unsupported tool_style: {chat_format.tool_style}')
elif response_schema:
converter._add_rule('root', response_rule)
grammar = converter.format_grammar()
def parse(s):
if s.endswith(suffix):
s = s[:-len(suffix)]
return {"role": "assistant", "content": s}
return None
parser = parse
(grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)
if chat_format.strict_user_assistant_alternation:
print("TODO: merge system messages into user messages")
@ -182,11 +86,9 @@ def main(
# TODO: Test whether the template supports formatting tool_calls
prompt = chat_format.render(messages, add_generation_prompt=True)
# print(prompt)
# print(grammar)
print(json.dumps(dict(
prompt=prompt,
stream=chat_request.stream,
prompt=prompt,
grammar=grammar,
), indent=2))
async with httpx.AsyncClient() as client:
@ -195,14 +97,23 @@ def main(
json=LlamaCppServerCompletionRequest(
prompt=prompt,
stream=chat_request.stream,
n_predict=100,
n_predict=300,
grammar=grammar,
).model_dump(),
headers=headers,
timeout=None)
return StreamingResponse(generate_chunks(response), media_type="text/event-stream") if chat_request.stream \
else JSONResponse(response.json())
if chat_request.stream:
# TODO: Remove suffix from streamed response using partial parser.
assert not chat_request.tools and not chat_request.response_format, "Streaming not supported yet with tools or response_format"
return StreamingResponse(generate_chunks(response), media_type="text/event-stream")
else:
result = response.json()
print(json.dumps(result, indent=2))
message = parser(result["content"])
assert message is not None, f"Failed to parse response: {response.text}"
return JSONResponse(message.model_dump())
# return JSONResponse(response.json())
async def generate_chunks(response):
async for chunk in response.aiter_bytes():
@ -211,5 +122,4 @@ def main(
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
CLI(main)
typer.run(main)

79
examples/openai/test.sh Executable file
View file

@ -0,0 +1,79 @@
#!/bin/bash
set -euo pipefail
SERVER_PID=""
function cleanup() {
if [ -n "$SERVER_PID" ]; then
echo "# Killing server"
kill $SERVER_PID
wait $SERVER_PID
fi
}
trap cleanup EXIT
echo "# Starting the server"
python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf &
SERVER_PID=$!
sleep 5
echo "# Send a message to the chat API"
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $OPENAI_API_KEY" \
-d '{
"model": "gpt-3.5-turbo",
"tools": [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location."
}
},
"required": ["location", "format"]
}
}
}, {
"type": "function",
"function": {
"name": "get_n_day_weather_forecast",
"description": "Get an N-day weather forecast",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location."
},
"num_days": {
"type": "integer",
"description": "The number of days to forecast"
}
},
"required": ["location", "format", "num_days"]
}
}
}],
"messages": [
{"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
{"role": "user", "content": "what is the weather going to be like in San Francisco and Glasgow over the next 4 days. Give the temperatyre in celsius for both locations."}
]
}'