agent: fix wait --std-tools

This commit is contained in:
Olivier Chafik 2024-04-10 19:47:01 +01:00 committed by ochafik
parent 89dcc062a4
commit 0120f7cc95
2 changed files with 67 additions and 46 deletions

View file

@ -3,6 +3,7 @@ import sys
from time import sleep from time import sleep
import typer import typer
from pydantic import BaseModel, Json, TypeAdapter from pydantic import BaseModel, Json, TypeAdapter
from pydantic_core import SchemaValidator, core_schema
from typing import Annotated, Any, Callable, Dict, List, Union, Optional, Type from typing import Annotated, Any, Callable, Dict, List, Union, Optional, Type
import json, requests import json, requests
@ -13,16 +14,12 @@ from examples.agent.utils import collect_functions, load_module
from examples.openai.prompting import ToolsPromptStyle from examples.openai.prompting import ToolsPromptStyle
from examples.openai.subprocesses import spawn_subprocess from examples.openai.subprocesses import spawn_subprocess
def _get_params_schema(fn: Callable[[Any], Any], verbose): def make_call_adapter(ta: TypeAdapter, fn: Callable[..., Any]):
if isinstance(fn, OpenAPIMethod): args_validator = SchemaValidator(core_schema.call_schema(
return fn.parameters_schema arguments=ta.core_schema['arguments_schema'],
function=fn,
# converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False) ))
schema = TypeAdapter(fn).json_schema() return lambda **kwargs: args_validator.validate_python(kwargs)
# Do NOT call converter.resolve_refs(schema) here. Let the server resolve local refs.
if verbose:
sys.stderr.write(f'# PARAMS SCHEMA: {json.dumps(schema, indent=2)}\n')
return schema
def completion_with_tool_usage( def completion_with_tool_usage(
*, *,
@ -50,18 +47,28 @@ def completion_with_tool_usage(
schema = type_adapter.json_schema() schema = type_adapter.json_schema()
response_format=ResponseFormat(type="json_object", schema=schema) response_format=ResponseFormat(type="json_object", schema=schema)
tool_map = {fn.__name__: fn for fn in tools} tool_map = {}
tools_schemas = [ tools_schemas = []
for fn in tools:
if isinstance(fn, OpenAPIMethod):
tool_map[fn.__name__] = fn
parameters_schema = fn.parameters_schema
else:
ta = TypeAdapter(fn)
tool_map[fn.__name__] = make_call_adapter(ta, fn)
parameters_schema = ta.json_schema()
if verbose:
sys.stderr.write(f'# PARAMS SCHEMA ({fn.__name__}): {json.dumps(parameters_schema, indent=2)}\n')
tools_schemas.append(
Tool( Tool(
type="function", type="function",
function=ToolFunction( function=ToolFunction(
name=fn.__name__, name=fn.__name__,
description=fn.__doc__ or '', description=fn.__doc__ or '',
parameters=_get_params_schema(fn, verbose=verbose) parameters=parameters_schema,
)
) )
) )
for fn in tools
]
i = 0 i = 0
while (max_iterations is None or i < max_iterations): while (max_iterations is None or i < max_iterations):
@ -106,7 +113,7 @@ def completion_with_tool_usage(
sys.stdout.write(f'⚙️ {pretty_call}') sys.stdout.write(f'⚙️ {pretty_call}')
sys.stdout.flush() sys.stdout.flush()
tool_result = tool_map[tool_call.function.name](**tool_call.function.arguments) tool_result = tool_map[tool_call.function.name](**tool_call.function.arguments)
sys.stdout.write(f" -> {tool_result}\n") sys.stdout.write(f" {tool_result}\n")
messages.append(Message( messages.append(Message(
tool_call_id=tool_call.id, tool_call_id=tool_call.id,
role="tool", role="tool",
@ -203,6 +210,8 @@ def main(
if std_tools: if std_tools:
tool_functions.extend(collect_functions(StandardTools)) tool_functions.extend(collect_functions(StandardTools))
sys.stdout.write(f'🛠️ {", ".join(fn.__name__ for fn in tool_functions)}\n')
response_model: Union[type, Json[Any]] = None #str response_model: Union[type, Json[Any]] = None #str
if format: if format:
if format in types: if format in types:

View file

@ -16,7 +16,18 @@ class Duration(BaseModel):
years: Optional[int] = None years: Optional[int] = None
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.years} years, {self.months} months, {self.days} days, {self.hours} hours, {self.minutes} minutes, {self.seconds} seconds" return ', '.join([
x
for x in [
f"{self.years} years" if self.years else None,
f"{self.months} months" if self.months else None,
f"{self.days} days" if self.days else None,
f"{self.hours} hours" if self.hours else None,
f"{self.minutes} minutes" if self.minutes else None,
f"{self.seconds} seconds" if self.seconds else None,
]
if x is not None
])
@property @property
def get_total_seconds(self) -> int: def get_total_seconds(self) -> int:
@ -36,25 +47,6 @@ class WaitForDuration(BaseModel):
sys.stderr.write(f"Waiting for {self.duration}...\n") sys.stderr.write(f"Waiting for {self.duration}...\n")
time.sleep(self.duration.get_total_seconds) time.sleep(self.duration.get_total_seconds)
class WaitForDate(BaseModel):
until: date
def __call__(self):
# Get the current date
current_date = datetime.date.today()
if self.until < current_date:
raise ValueError("Target date cannot be in the past.")
time_diff = datetime.datetime.combine(self.until, datetime.time.min) - datetime.datetime.combine(current_date, datetime.time.min)
days, seconds = time_diff.days, time_diff.seconds
sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {self.until}...\n")
time.sleep(days * 86400 + seconds)
sys.stderr.write(f"Reached the target date: {self.until}\n")
class StandardTools: class StandardTools:
@staticmethod @staticmethod
@ -66,12 +58,32 @@ class StandardTools:
return typer.prompt(question) return typer.prompt(question)
@staticmethod @staticmethod
def wait(_for: Union[WaitForDuration, WaitForDate]) -> None: def wait_for_duration(duration: Duration) -> None:
'Wait for a certain amount of time before continuing.'
# sys.stderr.write(f"Waiting for {duration}...\n")
time.sleep(duration.get_total_seconds)
@staticmethod
def wait_for_date(target_date: date) -> None:
f'''
Wait until a specific date is reached before continuing.
Today's date is {datetime.date.today()}
''' '''
Wait for a certain amount of time before continuing.
This can be used to wait for a specific duration or until a specific date. # Get the current date
''' current_date = datetime.date.today()
return _for()
if target_date < current_date:
raise ValueError("Target date cannot be in the past.")
time_diff = datetime.datetime.combine(target_date, datetime.time.min) - datetime.datetime.combine(current_date, datetime.time.min)
days, seconds = time_diff.days, time_diff.seconds
# sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {target_date}...\n")
time.sleep(days * 86400 + seconds)
# sys.stderr.write(f"Reached the target date: {target_date}\n")
@staticmethod @staticmethod
def say_out_loud(something: str) -> None: def say_out_loud(something: str) -> None: