agent: simplify tools setup

This commit is contained in:
Olivier Chafik 2024-10-25 01:03:45 +01:00
parent 0f4fc8cb28
commit 5c414a3335
6 changed files with 20 additions and 47 deletions

View file

@ -3,16 +3,12 @@ FROM python:3.12-slim
RUN python -m pip install --upgrade pip && \ RUN python -m pip install --upgrade pip && \
apt clean cache apt clean cache
COPY requirements.txt /root/ COPY requirements.txt tools/*.py /root/
# COPY . /root/
WORKDIR /root WORKDIR /root
RUN pip install -r requirements.txt RUN pip install -r requirements.txt
COPY ./*.py /root/
COPY ./tools/*.py /root/tools/
COPY ./squid/ssl_cert/squidCA.crt /usr/local/share/ca-certificates/squidCA.crt COPY ./squid/ssl_cert/squidCA.crt /usr/local/share/ca-certificates/squidCA.crt
RUN chmod 644 /usr/local/share/ca-certificates/squidCA.crt && update-ca-certificates RUN chmod 644 /usr/local/share/ca-certificates/squidCA.crt && update-ca-certificates
ENTRYPOINT [ "uvicorn" ] ENTRYPOINT [ "uvicorn" ]
CMD ["serve_tools:app", "--host", "0.0.0.0", "--port", "8088"] CMD ["tools:app", "--host", "0.0.0.0", "--port", "8088"]

View file

@ -1,5 +1,5 @@
aiohttp aiohttp
fastapi fastapi[standard]
ipython ipython
html2text html2text
requests requests

View file

@ -3,16 +3,14 @@
Usage (docker isolation - with network access): Usage (docker isolation - with network access):
docker run -p 8088:8088 -w /src -v $PWD/examples/agent:/src \ export BRAVE_SEARCH_API_KEY=...
--env BRAVE_SEARCH_API_KEY=$BRAVE_SEARCH_API_KEY \ ./examples/agent/serve_tools_inside_docker.sh
--rm -it ghcr.io/astral-sh/uv:python3.12-alpine \
uv run serve_tools.py --port 8088
Usage (non-siloed, DANGEROUS): Usage (non-siloed, DANGEROUS):
uv run examples/agent/serve_tools.py --port 8088 pip install -r examples/agent/requirements.txt
fastapi dev examples/agent/tools/__init__.py --port 8088
''' '''
import asyncio
import logging import logging
import re import re
import fastapi import fastapi
@ -21,15 +19,9 @@ import sys
sys.path.insert(0, os.path.dirname(__file__)) sys.path.insert(0, os.path.dirname(__file__))
from tools.fetch import fetch_page from .fetch import fetch_page
from tools.search import brave_search from .search import brave_search
from tools.python import python, python_tools from .python import python, python_tools_registry
# try:
# # https://github.com/aio-libs/aiohttp/discussions/6044
# setattr(asyncio.sslproto._SSLProtocolTransport, "_start_tls_compatible", True) # type: ignore
# except Exception as e:
# print(f'Failed to patch asyncio: {e}', file=sys.stderr)
verbose = os.environ.get('VERBOSE', '0') == '1' verbose = os.environ.get('VERBOSE', '0') == '1'
include = os.environ.get('INCLUDE_TOOLS') include = os.environ.get('INCLUDE_TOOLS')
@ -47,6 +39,7 @@ ALL_TOOLS = {
} }
app = fastapi.FastAPI() app = fastapi.FastAPI()
for name, fn in ALL_TOOLS.items(): for name, fn in ALL_TOOLS.items():
if include and not re.match(include, fn.__name__): if include and not re.match(include, fn.__name__):
continue continue
@ -54,4 +47,4 @@ for name, fn in ALL_TOOLS.items():
continue continue
app.post(f'/{name}')(fn) app.post(f'/{name}')(fn)
if name != 'python': if name != 'python':
python_tools[name] = fn python_tools_registry[name] = fn

View file

@ -1,4 +1,3 @@
# import aiohttp
import html2text import html2text
import logging import logging
import requests import requests
@ -14,12 +13,6 @@ async def fetch_page(url: str):
response = requests.get(url) response = requests.get(url)
response.raise_for_status() response.raise_for_status()
content = response.text content = response.text
# async with aiohttp.ClientSession(trust_env=True) as session:
# async with session.get(url) as res:
# res.raise_for_status()
# content = await res.text()
# except aiohttp.ClientError as e:
# raise Exception(f'Failed to fetch {url}: {e}')
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
raise Exception(f'Failed to fetch {url}: {e}') raise Exception(f'Failed to fetch {url}: {e}')

View file

@ -5,7 +5,7 @@ import logging
import sys import sys
python_tools = {} python_tools_registry = {}
def _strip_ansi_codes(text): def _strip_ansi_codes(text):
@ -27,7 +27,7 @@ def python(code: str) -> str:
shell = InteractiveShell( shell = InteractiveShell(
colors='neutral', colors='neutral',
) )
shell.user_global_ns.update(python_tools) shell.user_global_ns.update(python_tools_registry)
old_stdout = sys.stdout old_stdout = sys.stdout
sys.stdout = out = StringIO() sys.stdout = out = StringIO()

View file

@ -1,4 +1,3 @@
# import aiohttp
import itertools import itertools
import json import json
import logging import logging
@ -52,6 +51,7 @@ async def brave_search(*, query: str) -> List[Dict]:
} }
def extract_results(search_response): def extract_results(search_response):
# print("SEARCH RESPONSE: " + json.dumps(search_response, indent=2))
for m in search_response['mixed']['main']: for m in search_response['mixed']['main']:
result_type = m['type'] result_type = m['type']
keys = _result_keys_by_type.get(result_type) keys = _result_keys_by_type.get(result_type)
@ -66,19 +66,10 @@ async def brave_search(*, query: str) -> List[Dict]:
for r in results_of_type: for r in results_of_type:
yield _extract_values(keys, r) yield _extract_values(keys, r)
res = requests.get(url, headers=headers) response = requests.get(url, headers=headers)
if not res.ok: if not response.ok:
raise Exception(res.text) raise Exception(response.text)
reponse = res.json() response.raise_for_status()
res.raise_for_status() results = list(itertools.islice(extract_results(response.json()), max_results))
response = res.text
# async with aiohttp.ClientSession(trust_env=True) as session:
# async with session.get(url, headers=headers) as res:
# if not res.ok:
# raise Exception(await res.text())
# res.raise_for_status()
# response = await res.json()
results = list(itertools.islice(extract_results(response), max_results))
print(json.dumps(dict(query=query, response=response, results=results), indent=2)) print(json.dumps(dict(query=query, response=response, results=results), indent=2))
return results return results