From 0120f7cc954e012338814c835072859b1a07fb7d Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Wed, 10 Apr 2024 19:47:01 +0100 Subject: [PATCH] agent: fix wait --std-tools --- examples/agent/agent.py | 51 ++++++++++++++----------- examples/agent/tools/std_tools.py | 62 ++++++++++++++++++------------- 2 files changed, 67 insertions(+), 46 deletions(-) diff --git a/examples/agent/agent.py b/examples/agent/agent.py index a283e0628..03fb96dca 100644 --- a/examples/agent/agent.py +++ b/examples/agent/agent.py @@ -3,6 +3,7 @@ import sys from time import sleep import typer from pydantic import BaseModel, Json, TypeAdapter +from pydantic_core import SchemaValidator, core_schema from typing import Annotated, Any, Callable, Dict, List, Union, Optional, Type import json, requests @@ -13,16 +14,12 @@ from examples.agent.utils import collect_functions, load_module from examples.openai.prompting import ToolsPromptStyle from examples.openai.subprocesses import spawn_subprocess -def _get_params_schema(fn: Callable[[Any], Any], verbose): - if isinstance(fn, OpenAPIMethod): - return fn.parameters_schema - - # converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False) - schema = TypeAdapter(fn).json_schema() - # Do NOT call converter.resolve_refs(schema) here. Let the server resolve local refs. - if verbose: - sys.stderr.write(f'# PARAMS SCHEMA: {json.dumps(schema, indent=2)}\n') - return schema +def make_call_adapter(ta: TypeAdapter, fn: Callable[..., Any]): + args_validator = SchemaValidator(core_schema.call_schema( + arguments=ta.core_schema['arguments_schema'], + function=fn, + )) + return lambda **kwargs: args_validator.validate_python(kwargs) def completion_with_tool_usage( *, @@ -50,18 +47,28 @@ def completion_with_tool_usage( schema = type_adapter.json_schema() response_format=ResponseFormat(type="json_object", schema=schema) - tool_map = {fn.__name__: fn for fn in tools} - tools_schemas = [ - Tool( - type="function", - function=ToolFunction( - name=fn.__name__, - description=fn.__doc__ or '', - parameters=_get_params_schema(fn, verbose=verbose) + tool_map = {} + tools_schemas = [] + for fn in tools: + if isinstance(fn, OpenAPIMethod): + tool_map[fn.__name__] = fn + parameters_schema = fn.parameters_schema + else: + ta = TypeAdapter(fn) + tool_map[fn.__name__] = make_call_adapter(ta, fn) + parameters_schema = ta.json_schema() + if verbose: + sys.stderr.write(f'# PARAMS SCHEMA ({fn.__name__}): {json.dumps(parameters_schema, indent=2)}\n') + tools_schemas.append( + Tool( + type="function", + function=ToolFunction( + name=fn.__name__, + description=fn.__doc__ or '', + parameters=parameters_schema, + ) ) ) - for fn in tools - ] i = 0 while (max_iterations is None or i < max_iterations): @@ -106,7 +113,7 @@ def completion_with_tool_usage( sys.stdout.write(f'⚙️ {pretty_call}') sys.stdout.flush() tool_result = tool_map[tool_call.function.name](**tool_call.function.arguments) - sys.stdout.write(f" -> {tool_result}\n") + sys.stdout.write(f" → {tool_result}\n") messages.append(Message( tool_call_id=tool_call.id, role="tool", @@ -203,6 +210,8 @@ def main( if std_tools: tool_functions.extend(collect_functions(StandardTools)) + sys.stdout.write(f'🛠️ {", ".join(fn.__name__ for fn in tool_functions)}\n') + response_model: Union[type, Json[Any]] = None #str if format: if format in types: diff --git a/examples/agent/tools/std_tools.py b/examples/agent/tools/std_tools.py index 4d1e132a1..f4ee85036 100644 --- a/examples/agent/tools/std_tools.py +++ b/examples/agent/tools/std_tools.py @@ -16,7 +16,18 @@ class Duration(BaseModel): years: Optional[int] = None def __str__(self) -> str: - return f"{self.years} years, {self.months} months, {self.days} days, {self.hours} hours, {self.minutes} minutes, {self.seconds} seconds" + return ', '.join([ + x + for x in [ + f"{self.years} years" if self.years else None, + f"{self.months} months" if self.months else None, + f"{self.days} days" if self.days else None, + f"{self.hours} hours" if self.hours else None, + f"{self.minutes} minutes" if self.minutes else None, + f"{self.seconds} seconds" if self.seconds else None, + ] + if x is not None + ]) @property def get_total_seconds(self) -> int: @@ -36,25 +47,6 @@ class WaitForDuration(BaseModel): sys.stderr.write(f"Waiting for {self.duration}...\n") time.sleep(self.duration.get_total_seconds) -class WaitForDate(BaseModel): - until: date - - def __call__(self): - # Get the current date - current_date = datetime.date.today() - - if self.until < current_date: - raise ValueError("Target date cannot be in the past.") - - time_diff = datetime.datetime.combine(self.until, datetime.time.min) - datetime.datetime.combine(current_date, datetime.time.min) - - days, seconds = time_diff.days, time_diff.seconds - - sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {self.until}...\n") - time.sleep(days * 86400 + seconds) - sys.stderr.write(f"Reached the target date: {self.until}\n") - - class StandardTools: @staticmethod @@ -66,12 +58,32 @@ class StandardTools: return typer.prompt(question) @staticmethod - def wait(_for: Union[WaitForDuration, WaitForDate]) -> None: + def wait_for_duration(duration: Duration) -> None: + 'Wait for a certain amount of time before continuing.' + + # sys.stderr.write(f"Waiting for {duration}...\n") + time.sleep(duration.get_total_seconds) + + @staticmethod + def wait_for_date(target_date: date) -> None: + f''' + Wait until a specific date is reached before continuing. + Today's date is {datetime.date.today()} ''' - Wait for a certain amount of time before continuing. - This can be used to wait for a specific duration or until a specific date. - ''' - return _for() + + # Get the current date + current_date = datetime.date.today() + + if target_date < current_date: + raise ValueError("Target date cannot be in the past.") + + time_diff = datetime.datetime.combine(target_date, datetime.time.min) - datetime.datetime.combine(current_date, datetime.time.min) + + days, seconds = time_diff.days, time_diff.seconds + + # sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {target_date}...\n") + time.sleep(days * 86400 + seconds) + # sys.stderr.write(f"Reached the target date: {target_date}\n") @staticmethod def say_out_loud(something: str) -> None: