diff --git a/.gitmodules b/.gitmodules index 9d262566c..b7e8b8ff2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,3 @@ [submodule "kompute"] path = kompute url = https://github.com/nomic-ai/kompute.git -[submodule "examples/agents/hermes_function_calling"] - path = examples/agents/hermes_function_calling - url = https://github.com/NousResearch/Hermes-Function-Calling diff --git a/examples/agents/README.md b/examples/agents/README.md deleted file mode 100644 index eb743d0c1..000000000 --- a/examples/agents/README.md +++ /dev/null @@ -1,15 +0,0 @@ - -Edit `examples/agents/hermes_function_calling/utils.py`: - -```py -log_folder = os.environ.get('LOG_FOLDER', os.path.join(script_dir, "inference_logs")) -``` - -Then run: - -```bash -REQUIREMENTS_FILE=<( cat examples/agents/hermes_function_calling/requirements.txt | grep -vE "bitsandbytes|flash-attn" ) \ - examples/agents/run_sandboxed_tools.sh \ - examples/agents/hermes_function_calling/functions.py \ - -e LOG_FOLDER=/data/inference_logs -``` \ No newline at end of file diff --git a/examples/agents/hermes_function_calling b/examples/agents/hermes_function_calling deleted file mode 160000 index b4f757e27..000000000 --- a/examples/agents/hermes_function_calling +++ /dev/null @@ -1 +0,0 @@ -Subproject commit b4f757e27d87f4ab408f706f482c25a8e1508d59 diff --git a/examples/agents/requirements.txt b/examples/agents/requirements.txt deleted file mode 100644 index 2ff0ce927..000000000 --- a/examples/agents/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -jsonargparse -pydantic -typer[all] \ No newline at end of file diff --git a/examples/openai/README.md b/examples/openai/README.md index 47c9c67cc..078b52bba 100644 --- a/examples/openai/README.md +++ b/examples/openai/README.md @@ -1,15 +1,45 @@ -# examples.openai: OpenAI API-compatible server +# examples.openai: OpenAI API-compatible server + agent / tools examples A simple Python server that sits above the C++ [../server](examples/server) and offers improved OAI compatibility. 
## Usage +Run a simple test: + ```bash -python -m examples.openai -m some-model.gguf - - +# Spawns a Python server (which spawns a C++ Server) then hits it w/ a tool-calling request +examples/openai/test.sh ``` +To simply run the Python server (+ C++ server under the hood): + +```bash +python -m examples.openai +``` + +## Tools usage (WIP) + +```bash +git clone https://github.com/NousResearch/Hermes-Function-Calling examples/openai/hermes_function_calling +``` + +Then edit `examples/openai/hermes_function_calling/utils.py`: + +```py +log_folder = os.environ.get('LOG_FOLDER', os.path.join(script_dir, "inference_logs")) +``` + +Then run tools in a sandbox: + +```bash +REQUIREMENTS_FILE=<( cat examples/openai/hermes_function_calling/requirements.txt | grep -vE "bitsandbytes|flash-attn" ) \ + examples/openai/run_sandboxed_tools.sh \ + examples/openai/hermes_function_calling/functions.py \ + -e LOG_FOLDER=/data/inference_logs +``` + +TODO: a reactor that reads OpenAPI definitions and does the tool calling + ## Features The new examples/openai/server.py: diff --git a/examples/openai/api.py b/examples/openai/api.py index b883ecec4..c44c6bfd1 100644 --- a/examples/openai/api.py +++ b/examples/openai/api.py @@ -1,9 +1,14 @@ -from typing import Any, Optional +from typing import Any, Dict, Optional from pydantic import BaseModel, Json +class ToolCall(BaseModel): + name: str + arguments: Dict[str, Any] + class Message(BaseModel): role: str - content: str + content: Optional[str] + tool_calls: Optional[list[ToolCall]] = None class ToolFunction(BaseModel): name: str diff --git a/examples/openai/chat_format.py b/examples/openai/chat_format.py deleted file mode 100644 index bb7d0c94c..000000000 --- a/examples/openai/chat_format.py +++ /dev/null @@ -1,59 +0,0 @@ -from enum import StrEnum -import jinja2 - -from examples.openai.gguf_kvs import GGUFKeyValues, Keys - -def raise_exception(msg: str): - raise Exception(msg) - -class ToolStyle(StrEnum): - # https://cookbook.openai.com/examples/how_to_call_functions_with_chat_models - DEFAULT="Default", - # https://github.com/MeetKai/functionary - # TODO: look at https://github.com/ggerganov/llama.cpp/pull/5695 - # https://github.com/MeetKai/functionary/blob/main/functionary/prompt_template/prompt_template_v2.py - FUNCTIONARY_V2="Functionary V2", - # https://github.com/NousResearch/Hermes-Function-Calling - NOUS_RESEARCH_HERMES="Nous-Research-Hermes-Function-Calling", - -class ChatFormat: #(BaseModel): - def __init__(self, template: str, eos_token: str, bos_token: str): - env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True) - self.template = env.from_string(template) - self.eos_token = eos_token - self.bos_token = bos_token - - self.strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template - - if "<|recipient|>' + tool_call['function']['name']" in template: - self.tool_style = ToolStyle.FUNCTIONARY_V2 - else: - self.tool_style = ToolStyle.DEFAULT - - - def __str__(self): - return f"ChatFormat(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})" - - - @staticmethod - def from_gguf(metadata: GGUFKeyValues): - return ChatFormat( - template = metadata[Keys.Tokenizer.CHAT_TEMPLATE], - bos_token = metadata[Keys.Tokenizer.BOS_ID], - eos_token = metadata[Keys.Tokenizer.EOS_ID]) - # @staticmethod - # def from_gguf(model: Path): - # reader = GGUFReader(model.as_posix()) - # return ChatFormat( - # template = 
reader.fields[Keys.Tokenizer.CHAT_TEMPLATE].read(), - # bos_token = reader.fields[Keys.Tokenizer.BOS_ID].read(), - # eos_token = reader.fields[Keys.Tokenizer.EOS_ID].read()) - - def render(self, messages: list[dict], add_generation_prompt: bool, omit_bos: bool = False): - return self.template.render( - messages=messages, - eos_token=self.eos_token, - bos_token='' if omit_bos else self.bos_token, - raise_exception=raise_exception, - add_generation_prompt=add_generation_prompt, - ) diff --git a/examples/agents/fastify-requirements.txt b/examples/openai/fastify-requirements.txt similarity index 100% rename from examples/agents/fastify-requirements.txt rename to examples/openai/fastify-requirements.txt diff --git a/examples/agents/fastify.py b/examples/openai/fastify.py similarity index 97% rename from examples/agents/fastify.py rename to examples/openai/fastify.py index 48df2bfda..8846a3823 100644 --- a/examples/agents/fastify.py +++ b/examples/openai/fastify.py @@ -8,8 +8,6 @@ from anyio import Path import fastapi, uvicorn import typer -# from langchain_core.tools import BaseTool - def load_source_as_module(source): i = 0 while (module_name := f'mod_{i}') in sys.modules: diff --git a/examples/openai/prompt1.txt b/examples/openai/prompt1.txt new file mode 100644 index 000000000..afae47380 --- /dev/null +++ b/examples/openai/prompt1.txt @@ -0,0 +1,43 @@ +<|im_start|>system +Role: + You are a function calling AI agent with self-recursion. + You can call only one function at a time and analyse data you get from function response. + You are provided with function signatures within <tools></tools> XML tags. + The current date is: March 25, 2024. + +Objective: + You may use agentic frameworks for reasoning and planning to help with user query. + Please call a function and wait for function results to be provided to you in the next iteration. + Don't make assumptions about what values to plug into function arguments. + Once you have called a function, results will be fed back to you within <tool_response></tool_response> XML tags. + Don't make assumptions about tool results if <tool_response> XML tags are not present since function hasn't been executed yet. + Analyze the data once you get the results and call another function. + At each iteration please continue adding your analysis to the previous summary. + Your final response should directly answer the user query with an analysis or summary of the results of function calls. + +Tools: + Here are the available tools: + <tools> + {"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"format":{"type":"string","enum":["celsius","fahrenheit"],"description":"The temperature unit to use. Infer this from the users location."}},"required":["location","format"]}}} + {"type":"function","function":{"name":"get_n_day_weather_forecast","description":"Get an N-day weather forecast","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"format":{"type":"string","enum":["celsius","fahrenheit"],"description":"The temperature unit to use. 
Infer this from the users location."},"num_days":{"type":"integer","description":"The number of days to forecast"}},"required":["location","format","num_days"]}}} + </tools> + If the provided function signatures don't have the function you must call, you may write executable python code in markdown syntax and call code_interpreter() function as follows: + <tool_call> + {"arguments": {"code_markdown": <python-code>, "name": "code_interpreter"}} + </tool_call> + Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree. + +Instructions: + At the very first turn you don't have <tool_results> so you shouldn't make up the results. + Please keep a running summary with analysis of previous function results and summaries from previous iterations. + Do not stop calling functions until the task has been accomplished or you've reached max iteration of 10. + Calling multiple functions at once can overload the system and increase cost so call one function at a time please. + If you plan to continue with analysis, always call another function. + For each function call return a valid json object (using double quotes) with function name and arguments within <tool_call></tool_call> XML tags as follows: + <tool_call> + {"arguments": <args-dict>, "name": <function-name>} + </tool_call> +<|im_end|> +<|im_start|>user +what is the weather going to be like in San Francisco and Glasgow over the next 4 days (temperature in celsius for both)<|im_end|> +<|im_start|>assistant \ No newline at end of file diff --git a/examples/openai/prompting.py b/examples/openai/prompting.py new file mode 100644 index 000000000..71912fed5 --- /dev/null +++ b/examples/openai/prompting.py @@ -0,0 +1,242 @@ +from enum import Enum +import jinja2 +import json +from pathlib import Path +import sys +from typing import Optional, Tuple, Callable +from typeguard import typechecked + +from examples.json_schema_to_grammar import SchemaConverter +from examples.openai.api import Tool, Message +from examples.openai.gguf_kvs import GGUFKeyValues, Keys +from examples.openai.ts_converter import SchemaToTypeScriptConverter + +@typechecked +def raise_exception(msg: str): + raise Exception(msg) + +class ChatFormat: + def __init__(self, template: str, eos_token: str, bos_token: str): + env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True) + self.template = env.from_string(template) + self.eos_token = eos_token + self.bos_token = bos_token + + self.strict_user_assistant_alternation = "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception" in template + + if "<|recipient|>' + tool_call['function']['name']" in template: + self.tool_style = ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2 + else: + self.tool_style = ToolsPromptStyle.TOOLS_LONG + + def __str__(self): + return f"ChatFormat(template={self.template}, eos_token={self.eos_token}, bos_token={self.bos_token})" + + @staticmethod + def from_gguf(metadata: GGUFKeyValues): + return ChatFormat( + template = metadata[Keys.Tokenizer.CHAT_TEMPLATE], + bos_token = metadata[Keys.Tokenizer.BOS_ID], + eos_token = metadata[Keys.Tokenizer.EOS_ID]) + + def render(self, messages: list[dict], add_generation_prompt: bool, omit_bos: bool = False): + return self.template.render( + messages=messages, + eos_token=self.eos_token, + bos_token='' if omit_bos else self.bos_token, + raise_exception=raise_exception, + add_generation_prompt=add_generation_prompt, + ) + +# While the API will be usable with a generic tools usage like OpenAI, +# (see https://cookbook.openai.com/examples/how_to_call_functions_with_chat_models), +# 
each model may need specific prompting (and/or constrained output, +# especially for models not fine-tuned for tool usage / function calling). +class ToolsPromptStyle(Enum): + # Short prompt w/ schemas + TOOLS_SHORT = 1 + + # Longer prompt w/ schemas + TOOLS_LONG = 2 + + # Large prompt for https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B + # Requires: + # - git clone https://github.com/NousResearch/Hermes-Function-Calling examples/openai/hermes_function_calling + # - Set large context length as their prompts are super long + TOOLS_HERMES_2_PRO = 3 + + # Short prompt w/ TypeScript definitions for https://github.com/MeetKai/functionary + # https://github.com/MeetKai/functionary/blob/main/functionary/prompt_template/prompt_template_v2.py + # Note: see this prior attempt to support Functionary: https://github.com/ggerganov/llama.cpp/pull/5695 + TYPESCRIPT_FUNCTIONARY_V2 = 4 + +@typechecked +def make_tools_prompt(chat_format: ChatFormat, tools: list[Tool], indent=2) -> Message: + + if chat_format.tool_style == ToolsPromptStyle.TOOLS_SHORT: + return Message( + role="system", + content='\n'.join([ + 'Here are the tools available:', + '<tools>', + *(json.dumps(tool.model_dump(), indent=indent) for tool in tools), + '</tools>', + ]) + ) + + elif chat_format.tool_style == ToolsPromptStyle.TOOLS_LONG: + return Message( + role="system", + content='\n'.join([ + '''You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.''', + '''You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:''', + '''<tools>''', + *(json.dumps(tool.model_dump(), indent=indent) for tool in tools), + '''</tools>''', + '', + '''Use the following json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}''', + '', + '''For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:''', + '''<tool_call>''', + '''{"arguments": <args-dict>, "name": <function-name>}''', + '''</tool_call>''', + ]) + ) + + elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2: + ts_converter = SchemaToTypeScriptConverter() + + return Message( + role="system", + content='\n'.join([ + '// Supported function definitions that should be called when necessary.\n' 
+ 'namespace functions {', + *[ + '// ' + tool.function.description.replace('\n', '\n// ') + '\n' + '' + 'type ' + tool.function.name + ' = (_: ' + ts_converter.visit(tool.function.parameters) + ") => any;\n" + for tool in tools + ], + '} // namespace functions', + ]) + ) + + elif chat_format.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO: + # Hackily import https://github.com/NousResearch/Hermes-Function-Calling + path = str(Path(__file__).parent / "hermes_function_calling") + if path not in sys.path: sys.path.insert(0, path) + try: + from examples.openai.hermes_function_calling.prompter import PromptManager + except ImportError: + raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`") + + prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in tools]) + assert len(prompt) == 1 and prompt[0]["role"] == "system" + return Message(**prompt[0]) + + else: + raise ValueError(f"Unsupported tool call style: {chat_format.tool_style}") + +@typechecked +def _outputs_tool_call_tags(style: ToolsPromptStyle) -> bool: + return style in ( + ToolsPromptStyle.TOOLS_SHORT, + ToolsPromptStyle.TOOLS_LONG, + ToolsPromptStyle.TOOLS_HERMES_2_PRO, + ) + +@typechecked +def make_grammar(chat_format: ChatFormat, tools: list[Tool], response_schema: Optional[dict], indent=2) -> Tuple[Optional[str], Callable[[str], Optional[Message]]]: + + converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False) + + response_rule = converter.visit(response_schema, "response") if response_schema else None + + delimiter = '<%$[SAMPLE]$%>' + empty_prompt = chat_format.render([], add_generation_prompt=True).strip() + planted_prompt = chat_format.render([{"role": "assistant", "content": delimiter}], add_generation_prompt=False).strip() + assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}" + [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter) + + if tools: + if _outputs_tool_call_tags(chat_format.tool_style): + tool_rules = [ + converter.visit( + dict( + type="object", + properties=dict( + name=dict(const=tool.function.name), + arguments=tool.function.parameters, + ), + required=['name', 'arguments'] + ), + f'{tool.function.name}-tool-call' + ) + for tool in tools + ] + + # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not) + # OR a tool-call message respecting the schema of any of the tools + converter._add_rule( + "root", + converter._format_literal(prefix) + " (" + + (response_rule or converter.not_literal("<tool_call>")) + " | " + + converter._format_literal("<tool_call>") + " (" + + ' | '.join(tool_rules) + + ") " + converter._format_literal("</tool_call>") + + ")") # + converter._format_literal(suffix)) + + @typechecked + def parse(s: str) -> Optional[Message]: + ls = s.lstrip() + if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'): + if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix): + tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)] + return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)]) + return None + else: + return Message(role="assistant", content=s) + + return (converter.format_grammar(), parse) + + elif chat_format.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2: + # Only allowing a single tool call at a time for now. 
+ # Note that if there were more, they'd be separated by a '<|from|>assistant' literal + converter._add_rule( + "root", + converter._format_literal(prefix) + " (" + + (response_rule or converter.not_literal("<|recipient|>")) + " | " + + (' | '.join( + converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " + + converter.visit(tool.function.parameters, tool.function.name + '-args') + for tool in tools + )) + + ") " + + ")") # + converter._format_literal(suffix)) + + @typechecked + def parse(s: str) -> Optional[Message]: + raise NotImplementedError(f'TODO: parse tool_style {chat_format.tool_style}: {s}') + + return (converter.format_grammar(), parse) + + elif response_schema: + converter._add_rule("root", response_rule + ' ' + converter._format_literal(suffix)) + + @typechecked + def parse(s: str) -> Optional[Message]: + if s.endswith(suffix): + return Message(role="assistant", content=s[:-len(suffix)]) + + return (converter.format_grammar(), parse) + + else: + converter._add_rule("root", converter._format_literal(prefix) + ' ' + converter._format_literal(suffix)) + + @typechecked + def parse(s: str) -> Optional[Message]: + if s.endswith(suffix): + return Message(role="assistant", content=s[:-len(suffix)]) + return None + + return (None, parse) + diff --git a/examples/openai/requirements.txt b/examples/openai/requirements.txt index 219fda417..b092bf19f 100644 --- a/examples/openai/requirements.txt +++ b/examples/openai/requirements.txt @@ -1,7 +1,7 @@ fastapi[all] gguf jinja2 -jsonargparse pydantic sse-starlette -uvicorn[all] \ No newline at end of file +uvicorn[all] +typer[all] \ No newline at end of file diff --git a/examples/agents/run_sandboxed_tools.sh b/examples/openai/run_sandboxed_tools.sh similarity index 93% rename from examples/agents/run_sandboxed_tools.sh rename to examples/openai/run_sandboxed_tools.sh index eb8eb252e..88e61f568 100755 --- a/examples/agents/run_sandboxed_tools.sh +++ b/examples/openai/run_sandboxed_tools.sh @@ -2,9 +2,9 @@ # # Runs a Python script in a sandboxed environment and makes its functions available as a web service. 
# -# git submodule add https://github.com/NousResearch/Hermes-Function-Calling examples/agents/hermes_function_calling -# python examples/agents/fastify.py examples/agents/hermes_function_calling/functions.py -# REQUIREMENTS_FILE=<( cat examples/agents/hermes_function_calling/requirements.txt | grep -vE "bitsandbytes|flash-attn" ) examples/agents/run_sandboxed_tools.sh examples/agents/hermes_function_calling/functions.py -e LOG_FOLDER=/data/inference_logs +# git submodule add https://github.com/NousResearch/Hermes-Function-Calling examples/openai/hermes_function_calling +# python examples/openai/fastify.py examples/openai/hermes_function_calling/functions.py +# REQUIREMENTS_FILE=<( cat examples/openai/hermes_function_calling/requirements.txt | grep -vE "bitsandbytes|flash-attn" ) examples/openai/run_sandboxed_tools.sh examples/openai/hermes_function_calling/functions.py -e LOG_FOLDER=/data/inference_logs set -euo pipefail script="$( realpath "$1" )" diff --git a/examples/openai/server.py b/examples/openai/server.py index db075bddb..1a639e3a6 100644 --- a/examples/openai/server.py +++ b/examples/openai/server.py @@ -1,35 +1,39 @@ +# https://gist.github.com/ochafik/a3d4a5b9e52390544b205f37fb5a0df3 +# pip install "fastapi[all]" "uvicorn[all]" sse-starlette jsonargparse jinja2 pydantic + import json, sys, subprocess, atexit from pathlib import Path -# sys.path.insert(0, str(Path(__file__).parent.parent)) +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest -from examples.json_schema_to_grammar import SchemaConverter - -from typing import Optional -import httpx -from fastapi import Depends, FastAPI, Request, Response -from starlette.responses import StreamingResponse -from fastapi.responses import JSONResponse -from jsonargparse import CLI - -from examples.openai.ts_converter import SchemaToTypeScriptConverter from examples.openai.gguf_kvs import GGUFKeyValues, Keys -from examples.openai.api import Message, Tool, ToolFunction, ResponseFormat, ChatCompletionRequest -from examples.openai.chat_format import ChatFormat, ToolStyle +from examples.openai.api import Message, ChatCompletionRequest +from examples.openai.prompting import ChatFormat, make_grammar, make_tools_prompt -def _add_system_prompt(messages: list['Message'], system_prompt: str): +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +import httpx +from starlette.responses import StreamingResponse +from typing import Annotated, Optional +import typer +from typeguard import typechecked + +@typechecked +def _add_system_prompt(messages: list[Message], system_prompt: Message) -> list[Message]: + assert system_prompt.role == "system" # TODO: add to last system message, or create a new one just before the last user message system_message = next(((i, m) for i, m in enumerate(messages) if m.role == "system"), None) if system_message is not None: (i, m) = system_message - messages[i].content = m.content + '\n' + system_prompt + return messages[:i] + [Message(role="system", content=m.content + '\n' + system_prompt.content)] + messages[i+1:] else: - messages.insert(0, Message(role="system", content=system_prompt)) - return messages + return [system_prompt] + messages def main( - model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"), + model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf", + # model: Path = 
Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"), + # model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None, host: str = "localhost", port: int = 8080, main_server_endpoint: Optional[str] = None, @@ -48,7 +52,7 @@ def main( "./server", "-m", model, "--host", main_server_host, "--port", f'{main_server_port}', '-ctk', 'q4_0', '-ctv', 'f16', - "-c", f"8192", + "-c", f"{2*8192}", # "-c", f"{context_length}", ]) atexit.register(server_process.kill) @@ -70,110 +74,10 @@ def main( response_schema = None messages = chat_request.messages - parser=None - grammar=None - - converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False) - - response_rule = converter.visit(response_schema, "response") if response_schema else None - - - delimiter = '<%$[SAMPLE]$%>' - empty_prompt = chat_format.render([], add_generation_prompt=True) - planted_prompt = chat_format.render([{"role": "assistant", "content": delimiter}], add_generation_prompt=False) - assert planted_prompt.startswith(empty_prompt), f"Planted prompt does not start with empty prompt: {planted_prompt} vs {empty_prompt}" - [prefix, suffix] = planted_prompt[len(empty_prompt):].split(delimiter) - if chat_request.tools: - if chat_format.tool_style in (ToolStyle.DEFAULT, ToolStyle.NOUS_RESEARCH_HERMES): - messages = _add_system_prompt(messages, '\n'.join([ - 'Here are the tools available:', - '', - *(tool.model_dump_json() for tool in chat_request.tools), - '', - ])) + messages = _add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools)) - tool_rules = [ - converter.visit( - dict( - type="object", - properties=dict( - name=dict(const=tool.function.name), - arguments=tool.function.parameters, - ), - required=['name', 'arguments'] - ), - f'{tool.function.name}-tool-call' - ) - for tool in chat_request.tools - ] - - # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not) - # OR a tool-call message respecting the schema of any of the tools - converter._add_rule( - "root", - converter._format_literal(prefix) + " (" + - (response_rule or converter.not_literal("")) + " | " + - converter._format_literal("") + " (" + - ' | '.join(tool_rules) + - ") " + converter._format_literal("") + - ") " + converter._format_literal(suffix)) - grammar = converter.format_grammar() - - def parse(s: str): - if ''.startswith(s): - if s.startswith('') and s.endswith('' + suffix): - s = s[len(''):-len('' + suffix)] - return {"role": "assistant", "tool_calls": [json.loads(s)]} - return None - else: - return {"role": "assistant", "content": s} - - parser = parse - - elif chat_format.tool_style == ToolStyle.FUNCTIONARY_V2: - - ts_converter = SchemaToTypeScriptConverter() - - messages = _add_system_prompt(messages, '\n'.join([ - '// Supported function definitions that should be called when necessary.' - 'namespace functions {', - *[ - '// ' + tool.function.description.replace('\n', '\n// ') + '\n' + '' - 'type ' + tool.function.name + ' = (_: ' + ts_converter.visit(tool.function.parameters) + ") => any;\n" - for tool in chat_request.tools - ], - '} // namespace functions', - ])) - - # Only allowing a single tool call at a time for now. 
- # Note that if there were more, they'd be separated by a '<|from|>assistant' literal - converter._add_rule( - "root", - converter._format_literal(prefix) + " (" + - (response_rule or converter.not_literal("<|recipient|>")) + " | " + - (' | '.join( - converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " + - converter.visit(tool.function.parameters, tool.function.name + '-args') - for tool in chat_request.tools - )) + - ") " + - ") " + converter._format_literal(suffix)) - grammar = converter.format_grammar() - else: - raise NotImplementedError(f'Unsupported tool_style: {chat_format.tool_style}') - - elif response_schema: - converter._add_rule('root', response_rule) - grammar = converter.format_grammar() - - def parse(s): - if s.endswith(suffix): - s = s[:-len(suffix)] - return {"role": "assistant", "content": s} - return None - - parser = parse + (grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema) if chat_format.strict_user_assistant_alternation: print("TODO: merge system messages into user messages") @@ -182,11 +86,9 @@ def main( # TODO: Test whether the template supports formatting tool_calls prompt = chat_format.render(messages, add_generation_prompt=True) - # print(prompt) - # print(grammar) print(json.dumps(dict( - prompt=prompt, stream=chat_request.stream, + prompt=prompt, grammar=grammar, ), indent=2)) async with httpx.AsyncClient() as client: @@ -195,14 +97,23 @@ def main( json=LlamaCppServerCompletionRequest( prompt=prompt, stream=chat_request.stream, - n_predict=100, + n_predict=300, grammar=grammar, ).model_dump(), headers=headers, timeout=None) - return StreamingResponse(generate_chunks(response), media_type="text/event-stream") if chat_request.stream \ - else JSONResponse(response.json()) + if chat_request.stream: + # TODO: Remove suffix from streamed response using partial parser. + assert not chat_request.tools and not chat_request.response_format, "Streaming not supported yet with tools or response_format" + return StreamingResponse(generate_chunks(response), media_type="text/event-stream") + else: + result = response.json() + print(json.dumps(result, indent=2)) + message = parser(result["content"]) + assert message is not None, f"Failed to parse response: {response.text}" + return JSONResponse(message.model_dump()) + # return JSONResponse(response.json()) async def generate_chunks(response): async for chunk in response.aiter_bytes(): @@ -211,5 +122,4 @@ def main( uvicorn.run(app, host=host, port=port) if __name__ == "__main__": - CLI(main) - + typer.run(main) diff --git a/examples/openai/test.sh b/examples/openai/test.sh new file mode 100755 index 000000000..5fffc54d4 --- /dev/null +++ b/examples/openai/test.sh @@ -0,0 +1,79 @@ +#!/bin/bash +set -euo pipefail + +SERVER_PID="" +function cleanup() { + if [ -n "$SERVER_PID" ]; then + echo "# Killing server" + kill $SERVER_PID + wait $SERVER_PID + fi +} +trap cleanup EXIT + +echo "# Starting the server" +python -m examples.openai --model ~/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf & +SERVER_PID=$! 
+ +sleep 5 + +echo "# Send a message to the chat API" + +curl http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "gpt-3.5-turbo", + "tools": [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location." + } + }, + "required": ["location", "format"] + } + } + }, { + "type": "function", + "function": { + "name": "get_n_day_weather_forecast", + "description": "Get an N-day weather forecast", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location." + }, + "num_days": { + "type": "integer", + "description": "The number of days to forecast" + } + }, + "required": ["location", "format", "num_days"] + } + } + }], + "messages": [ + {"role": "system", "content": "Do not make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."}, + {"role": "user", "content": "what is the weather going to be like in San Francisco and Glasgow over the next 4 days. Give the temperature in celsius for both locations."} + ] + }' +
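
For reference, the same kind of request can also be issued from Python. This is a minimal sketch (not part of the patch above), assuming the server started by `python -m examples.openai` is listening on `localhost:8080`; it posts a tool-enabled chat request and dumps whatever JSON the server returns, without assuming a particular response shape:

```py
# Minimal sketch: POST a tool-calling chat request to the local server and
# print the raw JSON response. Assumes the Python server from
# `python -m examples.openai` is already running on localhost:8080.
import json

import httpx

request = {
    # The model field is kept for OpenAI-API shape; the local server serves
    # whatever GGUF it was started with.
    "model": "gpt-3.5-turbo",
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
                    "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location", "format"],
            },
        },
    }],
    "messages": [
        {"role": "user", "content": "What is the weather like in Glasgow, in celsius?"},
    ],
}

response = httpx.post(
    "http://localhost:8080/v1/chat/completions",
    json=request,
    timeout=None,  # generation can be slow, don't time out
)
print(json.dumps(response.json(), indent=2))
```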