From 6f2191d99e3b98ac5a925f573eb00f1e1d87ab61 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Wed, 2 Oct 2024 17:54:20 +0100 Subject: [PATCH] `agent`: remove *lots* of cruft from tool definitions derived from FastAPI catalog (and remove wait* tools which can be implemented in Python anyway) --- examples/agent/run.py | 10 +++++- examples/agent/tools/fetch.py | 18 +++------- examples/agent/tools/wait.py | 67 ----------------------------------- 3 files changed, 13 insertions(+), 82 deletions(-) delete mode 100644 examples/agent/tools/wait.py diff --git a/examples/agent/run.py b/examples/agent/run.py index 40d18622b..a897952b6 100644 --- a/examples/agent/run.py +++ b/examples/agent/run.py @@ -65,10 +65,18 @@ class OpenAPIMethod: for name, param in self.parameters.items() } }, - components=catalog.get('components'), required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else []) ) + if (components := catalog.get('components', {})) is not None: + if (schemas := components.get('schemas')) is not None: + del schemas['HTTPValidationError'] + del schemas['ValidationError'] + if not schemas: + del components['schemas'] + if components: + self.parameters_schema['components'] = components + async def __call__(self, session: aiohttp.ClientSession, **kwargs): if self.body: body = kwargs.pop(self.body['name'], None) diff --git a/examples/agent/tools/fetch.py b/examples/agent/tools/fetch.py index 19488cb35..b825c0356 100644 --- a/examples/agent/tools/fetch.py +++ b/examples/agent/tools/fetch.py @@ -1,18 +1,9 @@ import aiohttp import html2text import logging -from pydantic import BaseModel -from typing import Optional -import sys -class FetchResult(BaseModel): - content: Optional[str] = None - markdown: Optional[str] = None - error: Optional[str] = None - - -async def fetch_page(url: str) -> FetchResult: +async def fetch_page(url: str) -> str: ''' Fetch a web page (convert it to markdown if possible). ''' @@ -24,8 +15,7 @@ async def fetch_page(url: str) -> FetchResult: res.raise_for_status() content = await res.text() except aiohttp.ClientError as e: - logging.error('[fetch_page] Failed to fetch %s: %s', url, e) - return FetchResult(error=str(e)) + raise Exception(f'Failed to fetch {url}: {e}') # NOTE: Pyppeteer doesn't work great in docker, short of installing a bunch of dependencies # from pyppeteer import launch @@ -54,7 +44,7 @@ async def fetch_page(url: str) -> FetchResult: h.ignore_images = False h.ignore_emphasis = False markdown = h.handle(content) - return FetchResult(markdown=markdown) + return markdown except Exception as e: logging.warning('[fetch_page] Failed to convert HTML of %s to markdown: %s', url, e) - return FetchResult(content=content) + return content diff --git a/examples/agent/tools/wait.py b/examples/agent/tools/wait.py deleted file mode 100644 index f0d7eccc7..000000000 --- a/examples/agent/tools/wait.py +++ /dev/null @@ -1,67 +0,0 @@ -import asyncio -import datetime -import logging -from pydantic import BaseModel -from typing import Optional - -class Duration(BaseModel): - seconds: Optional[int] = None - minutes: Optional[int] = None - hours: Optional[int] = None - days: Optional[int] = None - months: Optional[int] = None - years: Optional[int] = None - - def __str__(self) -> str: - return ', '.join([ - x - for x in [ - f"{self.years} years" if self.years else None, - f"{self.months} months" if self.months else None, - f"{self.days} days" if self.days else None, - f"{self.hours} hours" if self.hours else None, - f"{self.minutes} minutes" if self.minutes else None, - f"{self.seconds} seconds" if self.seconds else None, - ] - if x is not None - ]) - - @property - def get_total_seconds(self) -> float: - return sum([ - self.seconds or 0, - (self.minutes or 0)*60, - (self.hours or 0)*3600, - (self.days or 0)*86400, - (self.months or 0)*2592000, - (self.years or 0)*31536000, - ]) - -class WaitForDuration(BaseModel): - duration: Duration - -async def wait_for_duration(duration: Duration) -> None: - ''' - Wait for a certain amount of time before continuing. - ''' - - logging.debug(f"[wait_for_duration] Waiting for %s...", duration.get_total_seconds) - await asyncio.sleep(duration.get_total_seconds) - -async def wait_for_date(target_date: datetime.date) -> None: - f''' - Wait until a specific date is reached before continuing. - Today's date is {datetime.date.today()} - ''' - - current_date = datetime.date.today() - if target_date < current_date: - raise ValueError("Target date cannot be in the past.") - - logging.debug(f"[wait_for_date] Waiting until %s (current date = %s)...", target_date, current_date) - - time_diff = datetime.datetime.combine(target_date, datetime.time.min) - datetime.datetime.combine(current_date, datetime.time.min) - - days, seconds = time_diff.days, time_diff.seconds - - await asyncio.sleep(days * 86400 + seconds)