From 5c414a3335f6193709db6357e2f976ef1f78af6b Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Fri, 25 Oct 2024 01:03:45 +0100 Subject: [PATCH] `agent`: simplify tools setup --- examples/agent/Dockerfile.tools | 8 ++---- examples/agent/requirements.txt | 2 +- .../{serve_tools.py => tools/__init__.py} | 25 +++++++------------ examples/agent/tools/fetch.py | 7 ------ examples/agent/tools/python.py | 4 +-- examples/agent/tools/search.py | 21 +++++----------- 6 files changed, 20 insertions(+), 47 deletions(-) rename examples/agent/{serve_tools.py => tools/__init__.py} (53%) diff --git a/examples/agent/Dockerfile.tools b/examples/agent/Dockerfile.tools index d27b64803..fb3d474e8 100644 --- a/examples/agent/Dockerfile.tools +++ b/examples/agent/Dockerfile.tools @@ -3,16 +3,12 @@ FROM python:3.12-slim RUN python -m pip install --upgrade pip && \ apt clean cache -COPY requirements.txt /root/ -# COPY . /root/ +COPY requirements.txt tools/*.py /root/ WORKDIR /root RUN pip install -r requirements.txt -COPY ./*.py /root/ -COPY ./tools/*.py /root/tools/ - COPY ./squid/ssl_cert/squidCA.crt /usr/local/share/ca-certificates/squidCA.crt RUN chmod 644 /usr/local/share/ca-certificates/squidCA.crt && update-ca-certificates ENTRYPOINT [ "uvicorn" ] -CMD ["serve_tools:app", "--host", "0.0.0.0", "--port", "8088"] \ No newline at end of file +CMD ["tools:app", "--host", "0.0.0.0", "--port", "8088"] \ No newline at end of file diff --git a/examples/agent/requirements.txt b/examples/agent/requirements.txt index cc2d414d1..8e2d735fe 100644 --- a/examples/agent/requirements.txt +++ b/examples/agent/requirements.txt @@ -1,5 +1,5 @@ aiohttp -fastapi +fastapi[standard] ipython html2text requests diff --git a/examples/agent/serve_tools.py b/examples/agent/tools/__init__.py similarity index 53% rename from examples/agent/serve_tools.py rename to examples/agent/tools/__init__.py index b20d6dcdf..56e3e9681 100644 --- a/examples/agent/serve_tools.py +++ b/examples/agent/tools/__init__.py @@ -3,16 +3,14 @@ Usage (docker isolation - with network access): - docker run -p 8088:8088 -w /src -v $PWD/examples/agent:/src \ - --env BRAVE_SEARCH_API_KEY=$BRAVE_SEARCH_API_KEY \ - --rm -it ghcr.io/astral-sh/uv:python3.12-alpine \ - uv run serve_tools.py --port 8088 + export BRAVE_SEARCH_API_KEY=... + ./examples/agent/serve_tools_inside_docker.sh Usage (non-siloed, DANGEROUS): - uv run examples/agent/serve_tools.py --port 8088 + pip install -r examples/agent/requirements.txt + fastapi dev examples/agent/tools/__init__.py --port 8088 ''' -import asyncio import logging import re import fastapi @@ -21,15 +19,9 @@ import sys sys.path.insert(0, os.path.dirname(__file__)) -from tools.fetch import fetch_page -from tools.search import brave_search -from tools.python import python, python_tools - -# try: -# # https://github.com/aio-libs/aiohttp/discussions/6044 -# setattr(asyncio.sslproto._SSLProtocolTransport, "_start_tls_compatible", True) # type: ignore -# except Exception as e: -# print(f'Failed to patch asyncio: {e}', file=sys.stderr) +from .fetch import fetch_page +from .search import brave_search +from .python import python, python_tools_registry verbose = os.environ.get('VERBOSE', '0') == '1' include = os.environ.get('INCLUDE_TOOLS') @@ -47,6 +39,7 @@ ALL_TOOLS = { } app = fastapi.FastAPI() + for name, fn in ALL_TOOLS.items(): if include and not re.match(include, fn.__name__): continue @@ -54,4 +47,4 @@ for name, fn in ALL_TOOLS.items(): continue app.post(f'/{name}')(fn) if name != 'python': - python_tools[name] = fn + python_tools_registry[name] = fn diff --git a/examples/agent/tools/fetch.py b/examples/agent/tools/fetch.py index d1aff4887..89cd423b7 100644 --- a/examples/agent/tools/fetch.py +++ b/examples/agent/tools/fetch.py @@ -1,4 +1,3 @@ -# import aiohttp import html2text import logging import requests @@ -14,12 +13,6 @@ async def fetch_page(url: str): response = requests.get(url) response.raise_for_status() content = response.text - # async with aiohttp.ClientSession(trust_env=True) as session: - # async with session.get(url) as res: - # res.raise_for_status() - # content = await res.text() - # except aiohttp.ClientError as e: - # raise Exception(f'Failed to fetch {url}: {e}') except requests.exceptions.RequestException as e: raise Exception(f'Failed to fetch {url}: {e}') diff --git a/examples/agent/tools/python.py b/examples/agent/tools/python.py index 4dd2d9cc5..286530cf7 100644 --- a/examples/agent/tools/python.py +++ b/examples/agent/tools/python.py @@ -5,7 +5,7 @@ import logging import sys -python_tools = {} +python_tools_registry = {} def _strip_ansi_codes(text): @@ -27,7 +27,7 @@ def python(code: str) -> str: shell = InteractiveShell( colors='neutral', ) - shell.user_global_ns.update(python_tools) + shell.user_global_ns.update(python_tools_registry) old_stdout = sys.stdout sys.stdout = out = StringIO() diff --git a/examples/agent/tools/search.py b/examples/agent/tools/search.py index c36c2cbab..c89ac59c5 100644 --- a/examples/agent/tools/search.py +++ b/examples/agent/tools/search.py @@ -1,4 +1,3 @@ -# import aiohttp import itertools import json import logging @@ -52,6 +51,7 @@ async def brave_search(*, query: str) -> List[Dict]: } def extract_results(search_response): + # print("SEARCH RESPONSE: " + json.dumps(search_response, indent=2)) for m in search_response['mixed']['main']: result_type = m['type'] keys = _result_keys_by_type.get(result_type) @@ -66,19 +66,10 @@ async def brave_search(*, query: str) -> List[Dict]: for r in results_of_type: yield _extract_values(keys, r) - res = requests.get(url, headers=headers) - if not res.ok: - raise Exception(res.text) - reponse = res.json() - res.raise_for_status() - response = res.text - # async with aiohttp.ClientSession(trust_env=True) as session: - # async with session.get(url, headers=headers) as res: - # if not res.ok: - # raise Exception(await res.text()) - # res.raise_for_status() - # response = await res.json() - - results = list(itertools.islice(extract_results(response), max_results)) + response = requests.get(url, headers=headers) + if not response.ok: + raise Exception(response.text) + response.raise_for_status() + results = list(itertools.islice(extract_results(response.json()), max_results)) print(json.dumps(dict(query=query, response=response, results=results), indent=2)) return results