agent/openai:nits

ochafik 2024-03-29 17:00:53 +00:00
parent ce2fb0155f
commit ea34bd3e5c
10 changed files with 72 additions and 145 deletions

View file

@@ -76,7 +76,7 @@ This example relies on the new [OpenAI compatibility server](../openai).
agent.py → examples.openai → server.cpp
→ safe_tools.py
→ ( run_sandboxed_tools.sh : Docker → fastify.py ) → unsafe_tools.py → code interpreter, etc...
```
```
The agent can use tools written in Python or (soon) tools exposed under OpenAPI endpoints. It only has standard Python dependencies (e.g. no langchain).
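For illustration, a custom tool can be just a type-annotated Python function with a docstring; the agent is assumed to derive the tool's JSON schema from the signature and docstring, as the StandardTools examples further down suggest. A minimal sketch (the `get_current_weather` / `WeatherReport` names are hypothetical, not part of this repo):

```python
# Minimal sketch of a custom tool (hypothetical names).
# Assumption: the agent exposes the type hints and docstring to the model.
from typing import Optional
from pydantic import BaseModel

class WeatherReport(BaseModel):
    location: str
    temperature_celsius: float
    conditions: Optional[str] = None

def get_current_weather(location: str) -> WeatherReport:
    '''Returns the current weather for the given location.'''
    # Placeholder implementation; a real tool would query an actual weather API.
    return WeatherReport(location=location, temperature_celsius=21.0, conditions="sunny")
```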

View file

@@ -128,7 +128,7 @@ def main(
max_iterations: Optional[int] = 10,
std_tools: Optional[bool] = False,
auth: Optional[str] = None,
allow_parallel_calls: Optional[bool] = False,
parallel_calls: Optional[bool] = True,
verbose: bool = False,
model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
@@ -174,14 +174,14 @@ def main(
"python", "-m", "examples.openai.server",
"--model", model,
*(['--verbose'] if verbose else []),
*(['--allow-parallel-calls'] if allow_parallel_calls else []),
*(['--parallel-calls'] if parallel_calls else []),
*([f'--context-length={context_length}'] if context_length else []),
*([])
]
server_process = subprocess.Popen(cmd, stdout=sys.stderr)
atexit.register(server_process.kill)
sleep(5)
tool_functions = []
types = {}
for f in tools:
@@ -195,7 +195,7 @@ def main(
if std_tools:
tool_functions.extend(collect_functions(StandardTools))
response_model = None#str
if format:
if format in types:
@@ -207,8 +207,8 @@ def main(
response_model = json.loads(format)
except:
response_model = eval(format)
result = completion_with_tool_usage(
model="...",
endpoint=endpoint,

View file

@@ -41,4 +41,4 @@ def main(files: List[str], host: str = '0.0.0.0', port: int = 8000):
if __name__ == '__main__':
typer.run(main)

View file

@@ -11,7 +11,7 @@ script="$( realpath "$1" )"
script_folder="$(dirname "$script")"
shift 1
function cleanup {
function cleanup {
rm -rf "$BUILD_DIR"
echo "Deleted $BUILD_DIR"
}

View file

@@ -1,15 +1,11 @@
import atexit
from datetime import date
import datetime
from pydantic import BaseModel
import subprocess
import sys
from time import sleep
import time
import typer
from pydantic import BaseModel, Json, TypeAdapter
from annotated_types import MinLen
from typing import Annotated, Callable, List, Union, Literal, Optional, Type, get_args, get_origin
import json, requests
from typing import Union, Optional
class Duration(BaseModel):
seconds: Optional[int] = None
@@ -50,7 +46,7 @@ class WaitForDate(BaseModel):
sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {d}...\n")
time.sleep(days * 86400 + seconds)
sys.stderr.write(f"Reached the target date: {self.until}\n")
class StandardTools:
@@ -61,7 +57,7 @@ class StandardTools:
This allows getting additional information, requesting disambiguation, etc.
'''
return typer.prompt(question)
@staticmethod
def wait(_for: Union[WaitForDuration, WaitForDate]) -> None:
'''
@@ -69,7 +65,7 @@ class StandardTools:
This can be used to wait for a specific duration or until a specific date.
'''
return _for()
@staticmethod
def say_out_loud(something: str) -> str:
"""

View file

@@ -34,7 +34,7 @@ The new [examples/openai/server.py](./server.py):
}
// Where T is the output JSON schema, or 'any'
```
- Option to publicise schemas to models as TypeScript signatures (as for Functionary) or JSON schema.
- Supports models that require strict user/assistant alternation (like Mixtral Instruct) by merging system messages into user messages.
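A simplified sketch of that merging step (illustrative only; the actual logic lives in `ChatTemplate.render` further down): fold each system message into the next user message so the sequence strictly alternates user/assistant.

```python
# Simplified sketch: fold system messages into the following user message so
# the final list alternates user/assistant (as e.g. Mixtral Instruct requires).
# The real renderer in this repo handles more cases (tool messages, etc.).
def merge_system_into_user(messages: list[dict]) -> list[dict]:
    merged: list[dict] = []
    pending_system: list[str] = []
    for m in messages:
        if m["role"] == "system":
            pending_system.append(m["content"])
        elif m["role"] == "user" and pending_system:
            merged.append({"role": "user", "content": "\n".join(pending_system + [m["content"]])})
            pending_system = []
        else:
            merged.append(m)
    if pending_system:
        # No trailing user message to merge into: emit the leftovers as a user turn.
        merged.append({"role": "user", "content": "\n".join(pending_system)})
    return merged
```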
@@ -175,7 +175,7 @@ curl http://localhost:8080/v1/chat/completions \
- Evaluate options for session caching
- Pass session id & store / read from file?
- Support parent session ids for trees of thought?
- Support precaching long prompts from CLI / read session files?
@@ -186,4 +186,4 @@ curl http://localhost:8080/v1/chat/completions \
- Remove non-Python json-schema-to-grammar versions
- Reach out to frameworks to advertise new option.
- Reach out to frameworks to advertise new option.

View file

@@ -1,28 +1,12 @@
from typing import Optional
from pydantic import BaseModel, Json
from pydantic import Json
class LlamaCppServerCompletionRequest(BaseModel):
from examples.openai.api import LlamaCppParams
class LlamaCppServerCompletionRequest(LlamaCppParams):
prompt: str
stream: Optional[bool] = None
cache_prompt: Optional[bool] = None
n_predict: Optional[int] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
min_p: Optional[float] = None
tfs_z: Optional[float] = None
typical_p: Optional[float] = None
temperature: Optional[float] = None
dynatemp_range: Optional[float] = None
dynatemp_exponent: Optional[float] = None
repeat_last_n: Optional[int] = None
repeat_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
mirostat: Optional[bool] = None
mirostat_tau: Optional[float] = None
mirostat_eta: Optional[float] = None
penalize_nl: Optional[bool] = None
n_keep: Optional[int] = None
seed: Optional[int] = None
grammar: Optional[str] = None
json_schema: Optional[Json] = None
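The sampling fields removed here presumably move into `LlamaCppParams` in `examples/openai/api.py`, so the completion request only adds `prompt`, `stream` and `cache_prompt`. A rough sketch of that base class, assuming it simply collects the fields deleted above (the actual definition may differ):

```python
# Rough sketch (assumption): LlamaCppParams gathers the llama.cpp sampling
# parameters previously declared directly on LlamaCppServerCompletionRequest.
from typing import Optional
from pydantic import BaseModel, Json

class LlamaCppParams(BaseModel):
    n_predict: Optional[int] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    temperature: Optional[float] = None
    # ... plus the remaining sampling fields removed above (min_p, tfs_z, mirostat_*, ...)
    seed: Optional[int] = None
    grammar: Optional[str] = None
    json_schema: Optional[Json] = None
```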

View file

@@ -1,15 +1,13 @@
from abc import ABC, abstractmethod
from enum import Enum
from functools import wraps
import jinja2
import json
from pathlib import Path
import random
import re
import sys
from typing import Any, Dict, Literal, Optional, Tuple, Callable, Union
from typing import Optional
from pydantic import BaseModel
# from typeguard import typechecked
from examples.json_schema_to_grammar import SchemaConverter
from examples.openai.api import Tool, Message, FunctionCall, ToolCall
@@ -55,7 +53,7 @@ class ChatTemplate(BaseModel):
@property
def tool_style(self) -> 'ToolsPromptStyle':
return self._tool_style
def __init__(self, template: str, eos_token: str, bos_token: str):
super().__init__(template=template
)
@@ -75,7 +73,7 @@ class ChatTemplate(BaseModel):
# self._tool_style = ToolsPromptStyle.TOOLS_MISTRAL
# TODO: Test whether the template supports formatting tool_calls
delimiter = '<%$[SAMPLE]$%>'
user_msg = Message(role="user", content="Hey")
empty_prompt = self.render([user_msg], add_generation_prompt=True).strip()
@@ -112,7 +110,7 @@ class ChatTemplate(BaseModel):
def from_gguf(metadata: GGUFKeyValues):
if Keys.Tokenizer.CHAT_TEMPLATE not in metadata:
raise NotImplementedError(f'Only supporting models with {Keys.Tokenizer.CHAT_TEMPLATE} entry in their GGUF key-values (TODO: add default template, maybe pick llama2\'s?)')
tokens = metadata[Keys.Tokenizer.LIST]
return ChatTemplate(
template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
@@ -129,8 +127,6 @@ class ChatTemplate(BaseModel):
eos_token = tokenizer.eos_token)
def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
# sys.stderr.write(f'# strict_user_assistant_alternation={self._strict_user_assistant_alternation}\n')
# sys.stderr.write(f'# messages=' + "\n".join(json.dumps(m.model_dump(), indent=2) for m in messages) + '\n')
if self._strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
new_messages=[]
i = 0
@@ -161,8 +157,7 @@ class ChatTemplate(BaseModel):
i += 1
# print(f'new_messages={json.dumps(new_messages, indent=2)}')
messages = new_messages
# print(f'messages={messages}')
result = self._template.render(
messages=messages,
eos_token=self._eos_token,
@@ -170,7 +165,6 @@ class ChatTemplate(BaseModel):
raise_exception=raise_exception,
add_generation_prompt=add_generation_prompt,
)
# sys.stderr.write(f'\n# RENDERED:\n\n{result}\n\n')
return result
class ChatHandlerArgs(BaseModel):
@@ -192,7 +186,7 @@ class NoToolsChatHandler(ChatHandler):
def __init__(self, args: ChatHandlerArgs):
super().__init__(args)
assert not args.tools
if args.response_schema:
self.output_format_prompt = Message(
role="system",
@@ -206,21 +200,20 @@ class NoToolsChatHandler(ChatHandler):
self.output_format_prompt = None
self.grammar = None
# @typechecked
def parse(self, s: str) -> Optional[Message]:
return Message(role="assistant", content=s)
class ToolCallTagsChatHandler(ChatHandler):
def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, allow_parallel_calls: bool):
def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, parallel_calls: bool):
super().__init__(args)
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
tool_rules = []
for tool in self.args.tools:
parameters_schema = tool.function.parameters
parameters_schema = converter.resolve_refs(parameters_schema, tool.function.name)
tool_rules.append(converter.visit(
dict(
type="object",
@@ -245,7 +238,7 @@ class ToolCallTagsChatHandler(ChatHandler):
format_literal("<tool_call>") + " space (" +
' | '.join(tool_rules) +
") space " + format_literal("</tool_call>"))# + ' space')
# Ideally we'd want a negative lookahead of /<tool\\?_call>/, but it's just too hard to express in GBNF for now.
# So we just over-constrain the content rule to not contain literals dangerously getting close to <tool_call>
content_rule = converter._add_rule('content', '[^<] | "<" [^t<] | "<t" [^o<]')
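Concretely, the tool_call and content rules combine (in the root rule of the next hunk) so the model may emit free text interleaved with tagged calls; an output this grammar is intended to admit looks roughly like the following (illustrative only, with a hypothetical get_current_weather tool):

```
Let me check. <tool_call> {"name": "get_current_weather", "arguments": {"location": "Paris"}} </tool_call>
```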
@@ -253,22 +246,10 @@ class ToolCallTagsChatHandler(ChatHandler):
converter._add_rule(
'root',
# tool_call_rule)
f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if allow_parallel_calls \
f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if parallel_calls \
else f'{content_rule}* {tool_call_rule}?')
self.grammar = converter.format_grammar()
# # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
# # OR a tool-call message respecting the schema of any of the tools
# converter._add_rule(
# "root",
# converter._format_literal(prefix) + " (" +
# (response_rule or converter.not_literal("<tool_call>")) + " | " +
# converter._format_literal("<tool_call>") + " (" +
# ' | '.join(tool_rules) +
# ") " + converter._format_literal("</tool_call>") +
# ")") # + converter._format_literal(suffix))
# @typechecked
def parse(self, s: str) -> Optional[Message]:
s = self.args.chat_template.strip_suffix(s)
@@ -294,21 +275,14 @@ class ToolCallTagsChatHandler(ChatHandler):
ToolCall(
id=gen_callid(),
function=FunctionCall(**fc)))
content = '\n'.join(content).strip()
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
# if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
# if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
# tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
# return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
# return None
# else:
# return Message(role="assistant", content=s)
class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
def __init__(self, args: ChatHandlerArgs, template: str, escapes_underscores=False, allow_parallel_calls=True):
super().__init__(args, escapes_underscores=escapes_underscores, allow_parallel_calls=allow_parallel_calls)
def __init__(self, args: ChatHandlerArgs, template: str, parallel_calls: bool, escapes_underscores: bool = False):
super().__init__(args, escapes_underscores=escapes_underscores, parallel_calls=parallel_calls)
assert '{tools}' in template, 'Template must contain "{tools}"'
self.output_format_prompt = Message(
@@ -320,8 +294,8 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
)
class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
super().__init__(args, escapes_underscores=False, allow_parallel_calls=allow_parallel_calls)
def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
super().__init__(args, escapes_underscores=False, parallel_calls=parallel_calls)
# Hackily import https://github.com/NousResearch/Hermes-Function-Calling
path = str(Path(__file__).parent / "hermes_function_calling")
@@ -330,15 +304,15 @@ class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
from examples.openai.hermes_function_calling.prompter import PromptManager
except ImportError:
raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`")
prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in tools])
prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in args.tools])
assert len(prompt) == 1 and prompt[0]["role"] == "system"
self.output_format_prompt = Message(**prompt[0])
class FunctionaryToolsChatHandler(ChatHandler):
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
super().__init__(args)
# Only allowing a single tool call at a time for now.
# Note that if there were more, they'd be separated by a '<|from|>assistant' literal
@@ -347,7 +321,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
content= '// Supported function definitions that should be called when necessary.\n' +
_tools_typescript_signatures(args.tools)
)
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
tool_rules = [
converter._add_rule(
@@ -355,17 +329,6 @@ class FunctionaryToolsChatHandler(ChatHandler):
converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
converter._format_literal('\n'))
# converter.visit(
# dict(
# type="object",
# properties=dict(
# name=dict(const=tool.function.name),
# arguments=tool.function.parameters,
# ),
# required=['name', 'arguments']
# ),
# f'{tool.function.name}-tool-call'
# )
for i, tool in enumerate(self.args.tools)
]
@@ -378,33 +341,18 @@ class FunctionaryToolsChatHandler(ChatHandler):
tool_call_without_start_rule = converter._add_rule(
'tool_call_without_start',
' | '.join(tool_rules))
# + ' ' +
# converter.not_literal("all", dotall=False) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + not_from_rule + '*')
tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}')
# converter._add_rule('root', f'({content_without_start_rule} ({content_rule})* ({tool_call_rule}+ {content_rule}*)? | {tool_call_without_start_rule} (* {tool_call_rule}{content_rule}*')
converter._add_rule(
'root',
f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if allow_parallel_calls \
f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if parallel_calls \
else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
self.grammar = converter.format_grammar()
# converter._add_rule(
# "root",
# converter._format_literal(prefix) + " (" +
# (response_rule or converter.not_literal("<|recipient|>")) + " | " +
# (' | '.join(
# converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
# converter.visit(tool.function.parameters, tool.function.name + '-args')
# for tool in tools
# )) +
# ") " +
# ")") # + converter._format_literal(suffix))
# @typechecked
def parse(self, s: str) -> Optional[Message]:
s = self.args.chat_template.strip_suffix(s)
parts = _recipient_content_re.split(s)
if len(parts) == 1:
return Message(role="assistant", content=s)
@@ -426,14 +374,14 @@ class FunctionaryToolsChatHandler(ChatHandler):
ToolCall(
id=gen_callid(),
function=FunctionCall(name=recipient, arguments=arguments)))
assert parts[-1].strip() in ('', '<|stop|>'), f'Unexpected content after tool calls: {parts[-1]}\nFull string: {s}'
content = '\n'.join(text_content).strip()
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls):
def _make_bespoke_schema(response_schema, tool_call_schema, parallel_calls):
return {
"type": "object",
"properties": {
@@ -453,7 +401,7 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls
# "const": "tool_calls"
# },
"tool_calls": {
"prefixItems": tool_call_schema if allow_parallel_calls \
"prefixItems": tool_call_schema if parallel_calls \
else [tool_call_schema],
}
},
@@ -474,9 +422,9 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls
}
class BespokeToolsChatHandler(ChatHandler):
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
super().__init__(args)
# args.response_schema = args.response_schema or {}
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
@@ -497,7 +445,7 @@ class BespokeToolsChatHandler(ChatHandler):
for tool in self.args.tools
]
},
allow_parallel_calls=allow_parallel_calls,
parallel_calls=parallel_calls,
),
'',
)
@@ -525,13 +473,12 @@ class BespokeToolsChatHandler(ChatHandler):
},
"required": ["name", "arguments"]
},
allow_parallel_calls=allow_parallel_calls,
parallel_calls=parallel_calls,
)
),
])
)
# @typechecked
def parse(self, s: str) -> Optional[Message]:
s = self.args.chat_template.strip_suffix(s)
try:
@@ -579,19 +526,19 @@ _LONG_TEMPLATE='\n'.join([
# 'This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.''',
])
def get_chat_handler(args: ChatHandlerArgs, allow_parallel_calls=False) -> ChatHandler:
def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool) -> ChatHandler:
if not args.tools:
return NoToolsChatHandler(args)
elif args.chat_template.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
return FunctionaryToolsChatHandler(args, allow_parallel_calls=False)
return FunctionaryToolsChatHandler(args, parallel_calls=parallel_calls)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_SHORT:
return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, parallel_calls=parallel_calls)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_LONG:
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, parallel_calls=parallel_calls)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL:
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, escapes_underscores=True, allow_parallel_calls=allow_parallel_calls)
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, parallel_calls=parallel_calls, escapes_underscores=True)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE:
return BespokeToolsChatHandler(args, allow_parallel_calls=allow_parallel_calls)
return BespokeToolsChatHandler(args, parallel_calls=parallel_calls)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
return Hermes2ProToolsChatHandler(args, parallel_calls=parallel_calls)
else:

View file

@@ -31,7 +31,7 @@ def main(
# model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
host: str = "localhost",
port: int = 8080,
allow_parallel_calls: Optional[bool] = False,
parallel_calls: Optional[bool] = True,
auth: Optional[str] = None,
verbose: bool = False,
context_length: Optional[int] = None,
@@ -44,13 +44,13 @@ def main(
if endpoint:
sys.stderr.write(f"# WARNING: Unsure which model we're talking to, fetching its chat template from HuggingFace tokenizer of {template_hf_model_id_fallback}\n")
chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback)
else:
metadata = GGUFKeyValues(model)
if not context_length:
context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
if Keys.Tokenizer.CHAT_TEMPLATE in metadata:
chat_template = ChatTemplate.from_gguf(metadata)
else:
@@ -92,22 +92,22 @@ def main(
chat_handler = get_chat_handler(
ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools),
allow_parallel_calls=allow_parallel_calls
parallel_calls=parallel_calls
)
messages = chat_request.messages
if chat_handler.output_format_prompt:
messages = chat_template.add_system_prompt(messages, chat_handler.output_format_prompt)
prompt = chat_template.render(messages, add_generation_prompt=True)
if verbose:
sys.stderr.write(f'\n# REQUEST:\n\n{chat_request.model_dump_json(indent=2)}\n\n')
# sys.stderr.write(f'\n# MESSAGES:\n\n{TypeAdapter(list[Message]).dump_json(messages)}\n\n')
sys.stderr.write(f'\n# PROMPT:\n\n{prompt}\n\n')
sys.stderr.write(f'\n# GRAMMAR:\n\n{chat_handler.grammar}\n\n')
data = LlamaCppServerCompletionRequest(
**{
k: v
@@ -130,7 +130,7 @@ def main(
json=data,
headers=headers,
timeout=None)
if chat_request.stream:
# TODO: Remove suffix from streamed response using partial parser.
assert not chat_request.tools and not chat_request.response_format, "Streaming not supported yet with tools or response_format"
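For reference, a non-streaming request with a tool definition can be sent to the endpoint roughly like this (a sketch assuming the default localhost:8080 host/port from the options above; `get_current_weather` is a hypothetical tool used only for illustration):

```python
# Sketch of a non-streaming tool-call request against the OpenAI-compatible endpoint.
import requests

response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "...",  # placeholder model name, as in agent.py
        "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_current_weather",  # hypothetical tool
                "description": "Get the current weather for a location.",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }],
    },
    timeout=None,
)
print(response.json())
```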

View file

@@ -31,7 +31,7 @@ class SchemaToTypeScriptConverter:
[f"{self._desc_comment(additional_properties) if additional_properties else ''}[key: string]: {self.visit(additional_properties)}"]
if additional_properties is not None else []
)) + "}"
def visit(self, schema: dict):
def print_constant(v):
return json.dumps(v)