tool-call: make agent async

This commit is contained in:
ochafik 2024-09-28 19:11:09 +01:00
parent 05bbba9f8a
commit ef2a020276
3 changed files with 92 additions and 83 deletions

View file

@ -1,29 +1,30 @@
# /// script # /// script
# requires-python = ">=3.11" # requires-python = ">=3.11"
# dependencies = [ # dependencies = [
# "aiohttp",
# "fastapi", # "fastapi",
# "openai", # "openai",
# "pydantic", # "pydantic",
# "requests",
# "uvicorn",
# "typer", # "typer",
# "uvicorn",
# ] # ]
# /// # ///
import json import json
import openai import asyncio
import aiohttp
from functools import wraps
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam
from pydantic import BaseModel from pydantic import BaseModel
import requests
import sys import sys
import typer import typer
from typing import Annotated, Optional from typing import Annotated, Optional
import urllib.parse import urllib.parse
class OpenAPIMethod: class OpenAPIMethod:
def __init__(self, url, name, descriptor, catalog): def __init__(self, url, name, descriptor, catalog):
''' '''
Wraps a remote OpenAPI method as a Python function. Wraps a remote OpenAPI method as an async Python function.
''' '''
self.url = url self.url = url
self.__name__ = name self.__name__ = name
@ -69,7 +70,7 @@ class OpenAPIMethod:
required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else []) required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else [])
) )
def __call__(self, **kwargs): async def __call__(self, session: aiohttp.ClientSession, **kwargs):
if self.body: if self.body:
body = kwargs.pop(self.body['name'], None) body = kwargs.pop(self.body['name'], None)
if self.body['required']: if self.body['required']:
@ -86,16 +87,55 @@ class OpenAPIMethod:
assert param['in'] == 'query', 'Only query parameters are supported' assert param['in'] == 'query', 'Only query parameters are supported'
query_params[name] = value query_params[name] = value
params = "&".join(f"{name}={urllib.parse.quote(value)}" for name, value in query_params.items()) params = "&".join(f"{name}={urllib.parse.quote(str(value))}" for name, value in query_params.items() if value is not None)
url = f'{self.url}?{params}' url = f'{self.url}?{params}'
response = requests.post(url, json=body) async with session.post(url, json=body) as response:
response.raise_for_status() response.raise_for_status()
response_json = response.json() response_json = await response.json()
return response_json return response_json
async def discover_tools(tool_endpoints: list[str], verbose: bool = False) -> tuple[dict, list]:
tool_map = {}
tools = []
def main( async with aiohttp.ClientSession() as session:
for url in tool_endpoints:
assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}'
catalog_url = f'{url}/openapi.json'
async with session.get(catalog_url) as response:
response.raise_for_status()
catalog = await response.json()
for path, descriptor in catalog['paths'].items():
fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog)
tool_map[fn.__name__] = fn
if verbose:
sys.stderr.write(f'# PARAMS SCHEMA ({fn.__name__}): {json.dumps(fn.parameters_schema, indent=2)}\n')
tools.append(dict(
type="function",
function=dict(
name=fn.__name__,
description=fn.__doc__ or '',
parameters=fn.parameters_schema,
)
)
)
return tool_map, tools
def typer_async_workaround():
'Adapted from https://github.com/fastapi/typer/issues/950#issuecomment-2351076467'
def decorator(f):
@wraps(f)
def wrapper(*args, **kwargs):
return asyncio.run(f(*args, **kwargs))
return wrapper
return decorator
@typer_async_workaround()
async def main(
goal: Annotated[str, typer.Option()], goal: Annotated[str, typer.Option()],
api_key: str = '<unset>', api_key: str = '<unset>',
tool_endpoint: Optional[list[str]] = None, tool_endpoint: Optional[list[str]] = None,
@ -103,36 +143,9 @@ def main(
verbose: bool = False, verbose: bool = False,
endpoint: str = "http://localhost:8080/v1/", endpoint: str = "http://localhost:8080/v1/",
): ):
client = AsyncOpenAI(api_key=api_key, base_url=endpoint)
openai.api_key = api_key tool_map, tools = await discover_tools(tool_endpoint or [], verbose)
openai.base_url = endpoint
tool_map = {}
tools = []
# Discover tools using OpenAPI catalogs at the provided endpoints.
for url in (tool_endpoint or []):
assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}'
catalog_url = f'{url}/openapi.json'
catalog_response = requests.get(catalog_url)
catalog_response.raise_for_status()
catalog = catalog_response.json()
for path, descriptor in catalog['paths'].items():
fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog)
tool_map[fn.__name__] = fn
if verbose:
sys.stderr.write(f'# PARAMS SCHEMA ({fn.__name__}): {json.dumps(fn.parameters_schema, indent=2)}\n')
tools.append(dict(
type="function",
function=dict(
name=fn.__name__,
description=fn.__doc__ or '',
parameters=fn.parameters_schema,
)
)
)
sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n') sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')
@ -143,51 +156,46 @@ def main(
) )
] ]
i = 0 async with aiohttp.ClientSession() as session:
while (max_iterations is None or i < max_iterations): for i in range(max_iterations or sys.maxsize):
response = await client.chat.completions.create(
model="gpt-4o",
messages=messages,
tools=tools,
)
response = openai.chat.completions.create( if verbose:
model="gpt-4o", sys.stderr.write(f'# RESPONSE: {response}\n')
messages=messages,
tools=tools,
)
if verbose: assert len(response.choices) == 1
sys.stderr.write(f'# RESPONSE: {response}\n') choice = response.choices[0]
assert len(response.choices) == 1 content = choice.message.content
choice = response.choices[0] if choice.finish_reason == "tool_calls":
messages.append(choice.message) # type: ignore
assert choice.message.tool_calls
for tool_call in choice.message.tool_calls:
if content:
print(f'💭 {content}')
content = choice.message.content args = json.loads(tool_call.function.arguments)
if choice.finish_reason == "tool_calls": pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
messages.append(choice.message) # type: ignore sys.stdout.write(f'⚙️ {pretty_call}')
assert choice.message.tool_calls sys.stdout.flush()
for tool_call in choice.message.tool_calls: tool_result = await tool_map[tool_call.function.name](session, **args)
if content: sys.stdout.write(f"{tool_result}\n")
print(f'💭 {content}') messages.append(ChatCompletionToolMessageParam(
tool_call_id=tool_call.id,
role="tool",
content=json.dumps(tool_result),
))
else:
assert content
print(content)
return
args = json.loads(tool_call.function.arguments) if max_iterations is not None:
pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})' raise Exception(f"Failed to get a valid response after {max_iterations} tool calls")
sys.stdout.write(f'⚙️ {pretty_call}')
sys.stdout.flush()
tool_result = tool_map[tool_call.function.name](**args)
sys.stdout.write(f"{tool_result}\n")
messages.append(ChatCompletionToolMessageParam(
tool_call_id=tool_call.id,
role="tool",
# name=tool_call.function.name,
content=json.dumps(tool_result),
# content=f'{pretty_call} = {tool_result}',
))
else:
assert content
print(content)
return
i += 1
if max_iterations is not None:
raise Exception(f"Failed to get a valid response after {max_iterations} tool calls")
if __name__ == '__main__': if __name__ == '__main__':
typer.run(main) typer.run(main)

View file

@ -89,7 +89,7 @@ def python(code: str) -> str:
Returns: Returns:
str: The output of the executed code. str: The output of the executed code.
""" """
from IPython import InteractiveShell from IPython.core.interactiveshell import InteractiveShell
from io import StringIO from io import StringIO
import sys import sys

View file

@ -1,6 +1,7 @@
aiohttp
fastapi fastapi
ipython
openai openai
pydantic pydantic
requests
typer typer
uvicorn uvicorn