agent/openai: nits

ochafik 2024-03-29 17:00:53 +00:00
parent ce2fb0155f
commit ea34bd3e5c
10 changed files with 72 additions and 145 deletions

View file

@@ -128,7 +128,7 @@ def main(
    max_iterations: Optional[int] = 10,
    std_tools: Optional[bool] = False,
    auth: Optional[str] = None,
-   allow_parallel_calls: Optional[bool] = False,
+   parallel_calls: Optional[bool] = True,
    verbose: bool = False,
    model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
@@ -174,7 +174,7 @@ def main(
        "python", "-m", "examples.openai.server",
        "--model", model,
        *(['--verbose'] if verbose else []),
-       *(['--allow-parallel-calls'] if allow_parallel_calls else []),
+       *(['--parallel-calls'] if parallel_calls else []),
        *([f'--context-length={context_length}'] if context_length else []),
        *([])
    ]
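Note that the rename also flips the default: parallel tool calls are now on by default and must be opted out of. As a standalone sketch (not part of the commit) of how such a boolean typer option behaves on the command line:

import typer
from typing import Optional

def main(parallel_calls: Optional[bool] = True):
    # typer derives paired --parallel-calls / --no-parallel-calls switches
    # from the boolean annotation, so callers opt out with --no-parallel-calls
    print(f"parallel_calls={parallel_calls}")

if __name__ == "__main__":
    typer.run(main)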

View file

@@ -1,15 +1,11 @@
import atexit
-from datetime import date
+import datetime
+from pydantic import BaseModel
import subprocess
import sys
-from time import sleep
+import time
import typer
-from pydantic import BaseModel, Json, TypeAdapter
-from annotated_types import MinLen
-from typing import Annotated, Callable, List, Union, Literal, Optional, Type, get_args, get_origin
import json, requests
+from typing import Union, Optional

class Duration(BaseModel):
    seconds: Optional[int] = None

View file

@@ -1,28 +1,12 @@
from typing import Optional
-from pydantic import BaseModel, Json
+from pydantic import Json
-class LlamaCppServerCompletionRequest(BaseModel):
+from examples.openai.api import LlamaCppParams
+class LlamaCppServerCompletionRequest(LlamaCppParams):
    prompt: str
    stream: Optional[bool] = None
    cache_prompt: Optional[bool] = None
-   n_predict: Optional[int] = None
-   top_k: Optional[int] = None
-   top_p: Optional[float] = None
-   min_p: Optional[float] = None
-   tfs_z: Optional[float] = None
-   typical_p: Optional[float] = None
-   temperature: Optional[float] = None
-   dynatemp_range: Optional[float] = None
-   dynatemp_exponent: Optional[float] = None
-   repeat_last_n: Optional[int] = None
-   repeat_penalty: Optional[float] = None
-   frequency_penalty: Optional[float] = None
-   presence_penalty: Optional[float] = None
-   mirostat: Optional[bool] = None
-   mirostat_tau: Optional[float] = None
-   mirostat_eta: Optional[float] = None
-   penalize_nl: Optional[bool] = None
-   n_keep: Optional[int] = None
-   seed: Optional[int] = None
-   grammar: Optional[str] = None
-   json_schema: Optional[Json] = None
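The deleted fields are not lost: they move into a shared LlamaCppParams base model in examples/openai/api, which the request type now inherits. A minimal sketch of the pattern with a few representative fields (the base class shape is assumed; it is not shown in this commit):

from typing import Optional
from pydantic import BaseModel

class LlamaCppParams(BaseModel):
    # sampling parameters shared by several request types (subset shown)
    n_predict: Optional[int] = None
    temperature: Optional[float] = None
    seed: Optional[int] = None
    grammar: Optional[str] = None

class LlamaCppServerCompletionRequest(LlamaCppParams):
    # request-specific fields stay on the subclass
    prompt: str
    stream: Optional[bool] = None
    cache_prompt: Optional[bool] = None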

View file

@@ -1,15 +1,13 @@
from abc import ABC, abstractmethod
from enum import Enum
from functools import wraps
import jinja2
import json
from pathlib import Path
import random
import re
import sys
-from typing import Any, Dict, Literal, Optional, Tuple, Callable, Union
+from typing import Optional
from pydantic import BaseModel
# from typeguard import typechecked
from examples.json_schema_to_grammar import SchemaConverter
from examples.openai.api import Tool, Message, FunctionCall, ToolCall
@@ -129,8 +127,6 @@ class ChatTemplate(BaseModel):
            eos_token = tokenizer.eos_token)

    def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
-       # sys.stderr.write(f'# strict_user_assistant_alternation={self._strict_user_assistant_alternation}\n')
-       # sys.stderr.write(f'# messages=' + "\n".join(json.dumps(m.model_dump(), indent=2) for m in messages) + '\n')
        if self._strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
            new_messages=[]
            i = 0
@@ -161,7 +157,6 @@ class ChatTemplate(BaseModel):
                i += 1
            # print(f'new_messages={json.dumps(new_messages, indent=2)}')
            messages = new_messages
-           # print(f'messages={messages}')

        result = self._template.render(
            messages=messages,
@@ -170,7 +165,6 @@ class ChatTemplate(BaseModel):
            raise_exception=raise_exception,
            add_generation_prompt=add_generation_prompt,
        )
-       # sys.stderr.write(f'\n# RENDERED:\n\n{result}\n\n')
        return result

class ChatHandlerArgs(BaseModel):
@@ -206,12 +200,11 @@ class NoToolsChatHandler(ChatHandler):
        self.output_format_prompt = None
        self.grammar = None

-   # @typechecked
    def parse(self, s: str) -> Optional[Message]:
        return Message(role="assistant", content=s)

class ToolCallTagsChatHandler(ChatHandler):
-   def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, allow_parallel_calls: bool):
+   def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, parallel_calls: bool):
        super().__init__(args)

        converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
@@ -253,22 +246,10 @@ class ToolCallTagsChatHandler(ChatHandler):
        converter._add_rule(
            'root',
            # tool_call_rule)
-           f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if allow_parallel_calls \
+           f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if parallel_calls \
                else f'{content_rule}* {tool_call_rule}?')
        self.grammar = converter.format_grammar()

-       # # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
-       # # OR a tool-call message respecting the schema of any of the tools
-       # converter._add_rule(
-       #     "root",
-       #     converter._format_literal(prefix) + " (" +
-       #         (response_rule or converter.not_literal("<tool_call>")) + " | " +
-       #         converter._format_literal("<tool_call>") + " (" +
-       #         ' | '.join(tool_rules) +
-       #         ") " + converter._format_literal("</tool_call>") +
-       #     ")") # + converter._format_literal(suffix))

-   # @typechecked
    def parse(self, s: str) -> Optional[Message]:
        s = self.args.chat_template.strip_suffix(s)
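The changed conditional above is the substance of the flag: the grammar's root rule either allows one or more tool calls interleaved with content, or at most one trailing tool call. Restated as a tiny standalone helper (rule names taken from this diff; the helper itself is illustrative, not in the commit):

def root_rule(content_rule: str, tool_call_rule: str, parallel_calls: bool) -> str:
    if parallel_calls:
        # any content, then optionally one-or-more tool calls with content in between
        return f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?'
    # any content followed by at most one tool call
    return f'{content_rule}* {tool_call_rule}?'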
@@ -298,17 +279,10 @@ class ToolCallTagsChatHandler(ChatHandler):
        content = '\n'.join(content).strip()
        return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)

-       # if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
-       #     if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
-       #         tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
-       #         return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
-       #     return None
-       # else:
-       #     return Message(role="assistant", content=s)

class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
-   def __init__(self, args: ChatHandlerArgs, template: str, escapes_underscores=False, allow_parallel_calls=True):
-       super().__init__(args, escapes_underscores=escapes_underscores, allow_parallel_calls=allow_parallel_calls)
+   def __init__(self, args: ChatHandlerArgs, template: str, parallel_calls: bool, escapes_underscores: bool = False):
+       super().__init__(args, escapes_underscores=escapes_underscores, parallel_calls=parallel_calls)
        assert '{tools}' in template, 'Template must contain "{tools}"'

        self.output_format_prompt = Message(
@@ -320,8 +294,8 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
        )

class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
-   def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
-       super().__init__(args, escapes_underscores=False, allow_parallel_calls=allow_parallel_calls)
+   def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
+       super().__init__(args, escapes_underscores=False, parallel_calls=parallel_calls)

        # Hackily import https://github.com/NousResearch/Hermes-Function-Calling
        path = str(Path(__file__).parent / "hermes_function_calling")
@@ -331,12 +305,12 @@ class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
        except ImportError:
            raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`")

-       prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in tools])
+       prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in args.tools])
        assert len(prompt) == 1 and prompt[0]["role"] == "system"
        self.output_format_prompt = Message(**prompt[0])

class FunctionaryToolsChatHandler(ChatHandler):
-   def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
+   def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
        super().__init__(args)

        # Only allowing a single tool call at a time for now.
@@ -355,17 +329,6 @@ class FunctionaryToolsChatHandler(ChatHandler):
                converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
                converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
                converter._format_literal('\n'))
-               # converter.visit(
-               #     dict(
-               #         type="object",
-               #         properties=dict(
-               #             name=dict(const=tool.function.name),
-               #             arguments=tool.function.parameters,
-               #         ),
-               #         required=['name', 'arguments']
-               #     ),
-               #     f'{tool.function.name}-tool-call'
-               # )
            for i, tool in enumerate(self.args.tools)
        ]
@@ -378,30 +341,15 @@ class FunctionaryToolsChatHandler(ChatHandler):
        tool_call_without_start_rule = converter._add_rule(
            'tool_call_without_start',
            ' | '.join(tool_rules))
-           # + ' ' +
-           # converter.not_literal("all", dotall=False) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + not_from_rule + '*')
        tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}')
-       # converter._add_rule('root', f'({content_without_start_rule} ({content_rule})* ({tool_call_rule}+ {content_rule}*)? | {tool_call_without_start_rule} (* {tool_call_rule}{content_rule}*')
        converter._add_rule(
            'root',
            f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
-               f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if allow_parallel_calls \
+               f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if parallel_calls \
                else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
        self.grammar = converter.format_grammar()

-       # converter._add_rule(
-       #     "root",
-       #     converter._format_literal(prefix) + " (" +
-       #         (response_rule or converter.not_literal("<|recipient|>")) + " | " +
-       #         (' | '.join(
-       #             converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
-       #             converter.visit(tool.function.parameters, tool.function.name + '-args')
-       #             for tool in tools
-       #         )) +
-       #         ") " +
-       #     ")") # + converter._format_literal(suffix))

-   # @typechecked
    def parse(self, s: str) -> Optional[Message]:
        s = self.args.chat_template.strip_suffix(s)
@@ -433,7 +381,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
        content = '\n'.join(text_content).strip()
        return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)

-def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls):
+def _make_bespoke_schema(response_schema, tool_call_schema, parallel_calls):
    return {
        "type": "object",
        "properties": {
@@ -453,7 +401,7 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls
                    # "const": "tool_calls"
                    # },
                    "tool_calls": {
-                       "prefixItems": tool_call_schema if allow_parallel_calls \
+                       "prefixItems": tool_call_schema if parallel_calls \
                            else [tool_call_schema],
                    }
                },
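For context, prefixItems in JSON Schema validates array items by position, so wrapping the schema in a one-element list constrains the first tool call only, while a longer list would describe several positional calls. A placeholder illustration (the schemas here are made up, not from the commit):

call_schema = {"type": "object", "required": ["name", "arguments"]}

# array whose first positional item must look like a tool call
tool_calls_schema = {"type": "array", "prefixItems": [call_schema]}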
@@ -474,7 +422,7 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls
    }

class BespokeToolsChatHandler(ChatHandler):
-   def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
+   def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
        super().__init__(args)

        # args.response_schema = args.response_schema or {}
@@ -497,7 +445,7 @@ class BespokeToolsChatHandler(ChatHandler):
                    for tool in self.args.tools
                ]
            },
-           allow_parallel_calls=allow_parallel_calls,
+           parallel_calls=parallel_calls,
        ),
        '',
    )
@@ -525,13 +473,12 @@ class BespokeToolsChatHandler(ChatHandler):
                        },
                        "required": ["name", "arguments"]
                    },
-                   allow_parallel_calls=allow_parallel_calls,
+                   parallel_calls=parallel_calls,
                )
            ),
        ])
    )

-   # @typechecked
    def parse(self, s: str) -> Optional[Message]:
        s = self.args.chat_template.strip_suffix(s)
        try:
@@ -579,19 +526,19 @@ _LONG_TEMPLATE='\n'.join([
    # 'This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.''',
])

-def get_chat_handler(args: ChatHandlerArgs, allow_parallel_calls=False) -> ChatHandler:
+def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool) -> ChatHandler:
    if not args.tools:
        return NoToolsChatHandler(args)
    elif args.chat_template.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
-       return FunctionaryToolsChatHandler(args, allow_parallel_calls=False)
+       return FunctionaryToolsChatHandler(args, parallel_calls=parallel_calls)
    elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_SHORT:
-       return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
+       return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, parallel_calls=parallel_calls)
    elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_LONG:
-       return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
+       return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, parallel_calls=parallel_calls)
    elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL:
-       return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, escapes_underscores=True, allow_parallel_calls=allow_parallel_calls)
+       return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, parallel_calls=parallel_calls, escapes_underscores=True)
    elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE:
-       return BespokeToolsChatHandler(args, allow_parallel_calls=allow_parallel_calls)
+       return BespokeToolsChatHandler(args, parallel_calls=parallel_calls)
    elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
        return Hermes2ProToolsChatHandler(args)
    else:
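With the default removed, parallel_calls is now a required argument of get_chat_handler, so every caller (such as the server hunk below) must pass it explicitly. A hedged usage sketch with placeholder values:

args = ChatHandlerArgs(chat_template=chat_template, response_schema=None, tools=tools)
handler = get_chat_handler(args, parallel_calls=True)  # must now be passed explicitly
grammar = handler.grammar  # GBNF constraint to send to the llama.cpp server, may be None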

View file

@@ -31,7 +31,7 @@ def main(
    # model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
    host: str = "localhost",
    port: int = 8080,
-   allow_parallel_calls: Optional[bool] = False,
+   parallel_calls: Optional[bool] = True,
    auth: Optional[str] = None,
    verbose: bool = False,
    context_length: Optional[int] = None,
@@ -92,7 +92,7 @@ def main(
        chat_handler = get_chat_handler(
            ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools),
-           allow_parallel_calls=allow_parallel_calls
+           parallel_calls=parallel_calls
        )

        messages = chat_request.messages