agent/openai:nits

ochafik 2024-03-29 17:00:53 +00:00
parent ce2fb0155f
commit ea34bd3e5c
10 changed files with 72 additions and 145 deletions


@@ -128,7 +128,7 @@ def main(
     max_iterations: Optional[int] = 10,
     std_tools: Optional[bool] = False,
     auth: Optional[str] = None,
-    allow_parallel_calls: Optional[bool] = False,
+    parallel_calls: Optional[bool] = True,
     verbose: bool = False,
     model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
@@ -174,7 +174,7 @@ def main(
         "python", "-m", "examples.openai.server",
         "--model", model,
         *(['--verbose'] if verbose else []),
-        *(['--allow-parallel-calls'] if allow_parallel_calls else []),
+        *(['--parallel-calls'] if parallel_calls else []),
         *(['--context-length={context_length}'] if context_length else []),
         *([])
     ]
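For reference, a minimal sketch of the argv this block assembles after the rename (values are illustrative; note that the '--context-length={context_length}' literal above has no f-string prefix, so as written it passes the braces through verbatim):

verbose, parallel_calls, context_length = True, True, 4096
model = "models/7B/ggml-model-f16.gguf"

cmd = [
    "python", "-m", "examples.openai.server",
    "--model", model,
    *(['--verbose'] if verbose else []),
    *(['--parallel-calls'] if parallel_calls else []),
    *([f'--context-length={context_length}'] if context_length else []),  # f-prefix added in this sketch only
]
# cmd -> ['python', '-m', 'examples.openai.server', '--model', 'models/7B/ggml-model-f16.gguf',
#         '--verbose', '--parallel-calls', '--context-length=4096']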


@ -1,15 +1,11 @@
import atexit
from datetime import date from datetime import date
import datetime import datetime
from pydantic import BaseModel
import subprocess import subprocess
import sys import sys
from time import sleep
import time import time
import typer import typer
from pydantic import BaseModel, Json, TypeAdapter from typing import Union, Optional
from annotated_types import MinLen
from typing import Annotated, Callable, List, Union, Literal, Optional, Type, get_args, get_origin
import json, requests
class Duration(BaseModel): class Duration(BaseModel):
seconds: Optional[int] = None seconds: Optional[int] = None


@@ -1,28 +1,12 @@
 from typing import Optional
-from pydantic import BaseModel, Json
+from pydantic import Json
 
-class LlamaCppServerCompletionRequest(BaseModel):
+from examples.openai.api import LlamaCppParams
+
+class LlamaCppServerCompletionRequest(LlamaCppParams):
     prompt: str
     stream: Optional[bool] = None
     cache_prompt: Optional[bool] = None
-    n_predict: Optional[int] = None
-    top_k: Optional[int] = None
-    top_p: Optional[float] = None
-    min_p: Optional[float] = None
-    tfs_z: Optional[float] = None
-    typical_p: Optional[float] = None
-    temperature: Optional[float] = None
-    dynatemp_range: Optional[float] = None
-    dynatemp_exponent: Optional[float] = None
-    repeat_last_n: Optional[int] = None
-    repeat_penalty: Optional[float] = None
-    frequency_penalty: Optional[float] = None
-    presence_penalty: Optional[float] = None
-    mirostat: Optional[bool] = None
-    mirostat_tau: Optional[float] = None
-    mirostat_eta: Optional[float] = None
-    penalize_nl: Optional[bool] = None
-    n_keep: Optional[int] = None
-    seed: Optional[int] = None
     grammar: Optional[str] = None
     json_schema: Optional[Json] = None
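The sampling fields deleted here presumably moved to the new LlamaCppParams base class imported from examples.openai.api (that file is not shown in this hunk); a rough sketch of what it would contain:

from typing import Optional
from pydantic import BaseModel

class LlamaCppParams(BaseModel):
    # Sketch: the fields removed above, relocated so the server request and
    # the OpenAI-style request types can share them.
    n_predict: Optional[int] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    min_p: Optional[float] = None
    temperature: Optional[float] = None
    seed: Optional[int] = None
    # ... plus the remaining params from the deletion (tfs_z, typical_p,
    # dynatemp_*, repeat_*, frequency_penalty, presence_penalty, mirostat_*,
    # penalize_nl, n_keep)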


@@ -1,15 +1,13 @@
 from abc import ABC, abstractmethod
 from enum import Enum
-from functools import wraps
 import jinja2
 import json
 from pathlib import Path
 import random
 import re
 import sys
-from typing import Any, Dict, Literal, Optional, Tuple, Callable, Union
+from typing import Optional
 from pydantic import BaseModel
-# from typeguard import typechecked
 
 from examples.json_schema_to_grammar import SchemaConverter
 from examples.openai.api import Tool, Message, FunctionCall, ToolCall
@@ -129,8 +127,6 @@ class ChatTemplate(BaseModel):
             eos_token = tokenizer.eos_token)
 
     def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
-        # sys.stderr.write(f'# strict_user_assistant_alternation={self._strict_user_assistant_alternation}\n')
-        # sys.stderr.write(f'# messages=' + "\n".join(json.dumps(m.model_dump(), indent=2) for m in messages) + '\n')
         if self._strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
             new_messages=[]
             i = 0
@@ -161,7 +157,6 @@ class ChatTemplate(BaseModel):
                 i += 1
             # print(f'new_messages={json.dumps(new_messages, indent=2)}')
             messages = new_messages
-        # print(f'messages={messages}')
 
         result = self._template.render(
             messages=messages,
@@ -170,7 +165,6 @@ class ChatTemplate(BaseModel):
             raise_exception=raise_exception,
             add_generation_prompt=add_generation_prompt,
         )
-        # sys.stderr.write(f'\n# RENDERED:\n\n{result}\n\n')
         return result
 
 class ChatHandlerArgs(BaseModel):
@@ -206,12 +200,11 @@ class NoToolsChatHandler(ChatHandler):
         self.output_format_prompt = None
         self.grammar = None
 
-    # @typechecked
     def parse(self, s: str) -> Optional[Message]:
         return Message(role="assistant", content=s)
 
 class ToolCallTagsChatHandler(ChatHandler):
-    def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, allow_parallel_calls: bool):
+    def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, parallel_calls: bool):
         super().__init__(args)
 
         converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
@@ -253,22 +246,10 @@ class ToolCallTagsChatHandler(ChatHandler):
         converter._add_rule(
             'root',
             # tool_call_rule)
-            f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if allow_parallel_calls \
+            f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if parallel_calls \
                 else f'{content_rule}* {tool_call_rule}?')
 
         self.grammar = converter.format_grammar()
 
-        # # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
-        # # OR a tool-call message respecting the schema of any of the tools
-        # converter._add_rule(
-        #     "root",
-        #     converter._format_literal(prefix) + " (" +
-        #         (response_rule or converter.not_literal("<tool_call>")) + " | " +
-        #         converter._format_literal("<tool_call>") + " (" +
-        #         ' | '.join(tool_rules) +
-        #         ") " + converter._format_literal("</tool_call>") +
-        #     ")") # + converter._format_literal(suffix))
-
-    # @typechecked
     def parse(self, s: str) -> Optional[Message]:
         s = self.args.chat_template.strip_suffix(s)
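In short, the root rule now toggles between a run of tool calls interleaved with content and at most one call; a minimal sketch of the selection, where the rule names stand for the converter rules built above:

def root_rule(content_rule: str, tool_call_rule: str, parallel_calls: bool) -> str:
    # parallel_calls=True:  content* (tool_call+ content*)?   -- zero or more calls
    # parallel_calls=False: content* tool_call?               -- at most one call
    if parallel_calls:
        return f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?'
    return f'{content_rule}* {tool_call_rule}?'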
@@ -298,17 +279,10 @@ class ToolCallTagsChatHandler(ChatHandler):
         content = '\n'.join(content).strip()
         return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
 
-        # if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
-        #     if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
-        #         tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
-        #         return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
-        #     return None
-        # else:
-        #     return Message(role="assistant", content=s)
 
 class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
-    def __init__(self, args: ChatHandlerArgs, template: str, escapes_underscores=False, allow_parallel_calls=True):
-        super().__init__(args, escapes_underscores=escapes_underscores, allow_parallel_calls=allow_parallel_calls)
+    def __init__(self, args: ChatHandlerArgs, template: str, parallel_calls: bool, escapes_underscores: bool = False):
+        super().__init__(args, escapes_underscores=escapes_underscores, parallel_calls=parallel_calls)
         assert '{tools}' in template, 'Template must contain "{tools}"'
 
         self.output_format_prompt = Message(
@@ -320,8 +294,8 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
         )
 
 class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
-    def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
-        super().__init__(args, escapes_underscores=False, allow_parallel_calls=allow_parallel_calls)
+    def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
+        super().__init__(args, escapes_underscores=False, parallel_calls=parallel_calls)
 
         # Hackily import https://github.com/NousResearch/Hermes-Function-Calling
         path = str(Path(__file__).parent / "hermes_function_calling")
@@ -331,12 +305,12 @@ class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
         except ImportError:
             raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`")
 
-        prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in tools])
+        prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in args.tools])
         assert len(prompt) == 1 and prompt[0]["role"] == "system"
         self.output_format_prompt = Message(**prompt[0])
 
 class FunctionaryToolsChatHandler(ChatHandler):
-    def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
+    def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
         super().__init__(args)
 
         # Only allowing a single tool call at a time for now.
@@ -355,17 +329,6 @@ class FunctionaryToolsChatHandler(ChatHandler):
                 converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
                 converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
                 converter._format_literal('\n'))
-            # converter.visit(
-            #     dict(
-            #         type="object",
-            #         properties=dict(
-            #             name=dict(const=tool.function.name),
-            #             arguments=tool.function.parameters,
-            #         ),
-            #         required=['name', 'arguments']
-            #     ),
-            #     f'{tool.function.name}-tool-call'
-            # )
             for i, tool in enumerate(self.args.tools)
         ]
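For orientation, the literals in these rules target the functionary-v2 wire format; a single generated call would look roughly like the following (the tool name and arguments are illustrative, and the exact start token comes from the surrounding rules):

# tool name + '\n<|content|>\n' separator + JSON arguments, per the rule above
completion = '<|recipient|>get_weather\n<|content|>\n{"city": "Paris"}\n'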
@@ -378,30 +341,15 @@ class FunctionaryToolsChatHandler(ChatHandler):
         tool_call_without_start_rule = converter._add_rule(
             'tool_call_without_start',
             ' | '.join(tool_rules))
-            # + ' ' +
-            # converter.not_literal("all", dotall=False) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + not_from_rule + '*')
         tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}')
 
-        # converter._add_rule('root', f'({content_without_start_rule} ({content_rule})* ({tool_call_rule}+ {content_rule}*)? | {tool_call_without_start_rule} (* {tool_call_rule}{content_rule}*')
         converter._add_rule(
             'root',
             f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
-            f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if allow_parallel_calls \
+            f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if parallel_calls \
                 else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
         self.grammar = converter.format_grammar()
 
-        # converter._add_rule(
-        #     "root",
-        #     converter._format_literal(prefix) + " (" +
-        #         (response_rule or converter.not_literal("<|recipient|>")) + " | " +
-        #         (' | '.join(
-        #             converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
-        #             converter.visit(tool.function.parameters, tool.function.name + '-args')
-        #             for tool in tools
-        #         )) +
-        #         ") " +
-        #     ")") # + converter._format_literal(suffix))
 
-    # @typechecked
     def parse(self, s: str) -> Optional[Message]:
         s = self.args.chat_template.strip_suffix(s)
@@ -433,7 +381,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
         content = '\n'.join(text_content).strip()
         return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
 
-def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls):
+def _make_bespoke_schema(response_schema, tool_call_schema, parallel_calls):
     return {
         "type": "object",
         "properties": {
@@ -453,7 +401,7 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls
             # "const": "tool_calls"
             # },
             "tool_calls": {
-                "prefixItems": tool_call_schema if allow_parallel_calls \
+                "prefixItems": tool_call_schema if parallel_calls \
                     else [tool_call_schema],
             }
         },
@@ -474,7 +422,7 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls
     }
 
 class BespokeToolsChatHandler(ChatHandler):
-    def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
+    def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
         super().__init__(args)
 
         # args.response_schema = args.response_schema or {}
@@ -497,7 +445,7 @@ class BespokeToolsChatHandler(ChatHandler):
                     for tool in self.args.tools
                 ]
             },
-            allow_parallel_calls=allow_parallel_calls,
+            parallel_calls=parallel_calls,
        ),
        '',
     )
@@ -525,13 +473,12 @@ class BespokeToolsChatHandler(ChatHandler):
                         },
                         "required": ["name", "arguments"]
                     },
-                    allow_parallel_calls=allow_parallel_calls,
+                    parallel_calls=parallel_calls,
                 )
             ),
         ])
     )
 
-    # @typechecked
     def parse(self, s: str) -> Optional[Message]:
         s = self.args.chat_template.strip_suffix(s)
         try:
@@ -579,19 +526,19 @@ _LONG_TEMPLATE='\n'.join([
     # 'This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.''',
 ])
 
-def get_chat_handler(args: ChatHandlerArgs, allow_parallel_calls=False) -> ChatHandler:
+def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool) -> ChatHandler:
     if not args.tools:
         return NoToolsChatHandler(args)
     elif args.chat_template.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
-        return FunctionaryToolsChatHandler(args, allow_parallel_calls=False)
+        return FunctionaryToolsChatHandler(args, parallel_calls=parallel_calls)
     elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_SHORT:
-        return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
+        return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, parallel_calls=parallel_calls)
     elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_LONG:
-        return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
+        return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, parallel_calls=parallel_calls)
     elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL:
-        return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, escapes_underscores=True, allow_parallel_calls=allow_parallel_calls)
+        return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, parallel_calls=parallel_calls, escapes_underscores=True)
     elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE:
-        return BespokeToolsChatHandler(args, allow_parallel_calls=allow_parallel_calls)
+        return BespokeToolsChatHandler(args, parallel_calls=parallel_calls)
     elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
         return Hermes2ProToolsChatHandler(args)
     else:


@@ -31,7 +31,7 @@ def main(
     # model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
     host: str = "localhost",
     port: int = 8080,
-    allow_parallel_calls: Optional[bool] = False,
+    parallel_calls: Optional[bool] = True,
     auth: Optional[str] = None,
     verbose: bool = False,
     context_length: Optional[int] = None,
@@ -92,7 +92,7 @@ def main(
     chat_handler = get_chat_handler(
         ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools),
-        allow_parallel_calls=allow_parallel_calls
+        parallel_calls=parallel_calls
     )
 
     messages = chat_request.messages
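To exercise the renamed flag end to end, something like this should work against the defaults above, assuming the server exposes an OpenAI-compatible /v1/chat/completions endpoint (the endpoint path and the tool definition are illustrative):

import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool, for illustration only
                "description": "Look up the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }],
    },
)
print(resp.json())  # with --parallel-calls (now the default) the grammar admits multiple tool_calls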