agent: display http errors nicely

This commit is contained in:
ochafik 2024-10-24 05:40:58 +01:00
parent f5320af02a
commit 0f5d63943f

View file

@ -14,10 +14,10 @@ from functools import wraps
import json import json
from openapi import discover_tools from openapi import discover_tools
import os import os
from pydantic import BaseModel from pydantic import BaseModel, Field, Json
import sys import sys
import typer import typer
from typing import Annotated, Literal, Optional from typing import Annotated, Dict, Literal, Optional
import urllib.parse import urllib.parse
@ -80,94 +80,101 @@ async def main(
tool_map, tools = await discover_tools(tools or [], verbose) tool_map, tools = await discover_tools(tools or [], verbose)
sys.stdout.write(f'🛠️ Tools: {", ".join(tool_map.keys()) if tool_map else "<none>"}\n') sys.stdout.write(f'🛠️ Tools: {", ".join(tool_map.keys()) if tool_map else "<none>"}\n')
try:
messages = [] messages = []
if system: if system:
messages.append(dict( messages.append(dict(
role='system', role='system',
content=system, content=system,
)) ))
messages.append( messages.append(
dict( dict(
role='user', role='user',
content=goal, content=goal,
)
)
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
}
async def run_turn():
for i in range(max_iterations or sys.maxsize):
url = f'{endpoint}chat/completions'
payload = dict(
messages=messages,
model=model,
tools=tools,
) )
if provider == 'llama.cpp': )
payload.update(dict(
seed=seed,
cache_prompt=cache_prompt,
)) # type: ignore
if verbose: headers = {
print(f'Calling {url} with {json.dumps(payload, indent=2)}', file=sys.stderr) 'Content-Type': 'application/json',
async with aiohttp.ClientSession(headers=headers) as session: 'Authorization': f'Bearer {api_key}'
async with session.post(url, json=payload) as response: }
response.raise_for_status() async def run_turn():
response = await response.json() for i in range(max_iterations or sys.maxsize):
if verbose: url = f'{endpoint}chat/completions'
print(f'Response: {json.dumps(response, indent=2)}', file=sys.stderr) payload = dict(
messages=messages,
model=model,
tools=tools,
)
if provider == 'llama.cpp':
payload.update(dict(
seed=seed,
cache_prompt=cache_prompt,
)) # type: ignore
assert len(response['choices']) == 1 if verbose:
choice = response['choices'][0] print(f'Calling {url} with {json.dumps(payload, indent=2)}', file=sys.stderr)
async with aiohttp.ClientSession(headers=headers) as session:
async with session.post(url, json=payload) as response:
response.raise_for_status()
response = await response.json()
if verbose:
print(f'Response: {json.dumps(response, indent=2)}', file=sys.stderr)
content = choice['message']['content'] assert len(response['choices']) == 1
if choice['finish_reason'] == 'tool_calls': choice = response['choices'][0]
messages.append(choice['message'])
assert choice['message']['tool_calls']
for tool_call in choice['message']['tool_calls']:
if content:
print(f'💭 {content}', file=sys.stderr)
name = tool_call['function']['name'] content = choice['message']['content']
args = json.loads(tool_call['function']['arguments']) if choice['finish_reason'] == 'tool_calls':
pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})' messages.append(choice['message'])
print(f'⚙️ {pretty_call}', file=sys.stderr, end=None) assert choice['message']['tool_calls']
sys.stdout.flush() for tool_call in choice['message']['tool_calls']:
try: if content:
tool_result = await tool_map[name](**args) print(f'💭 {content}', file=sys.stderr)
except Exception as e:
tool_result = 'ERROR: ' + str(e)
tool_result_str = tool_result if isinstance(tool_result, str) else json.dumps(tool_result)
def describe(res, res_str, max_len = 1000):
if isinstance(res, list):
return f'{len(res)} items'
return f'{len(res_str)} chars\n {res_str[:1000] if len(res_str) > max_len else res_str}...'
print(f'{describe(tool_result, tool_result_str)}', file=sys.stderr)
if verbose:
print(tool_result_str, file=sys.stderr)
messages.append(dict(
tool_call_id=tool_call.get('id'),
role='tool',
content=tool_result_str,
))
else:
assert content
print(content)
return
if max_iterations is not None: name = tool_call['function']['name']
raise Exception(f'Failed to get a valid response after {max_iterations} tool calls') args = json.loads(tool_call['function']['arguments'])
print(f'tool_call: {json.dumps(tool_call, indent=2)}', file=sys.stderr)
pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
print(f'⚙️ {pretty_call}', file=sys.stderr, end=None)
sys.stdout.flush()
try:
tool_result = await tool_map[name](**args)
except Exception as e:
tool_result = 'ERROR: ' + str(e)
tool_result_str = tool_result if isinstance(tool_result, str) else json.dumps(tool_result)
def describe(res, res_str, max_len = 1000):
if isinstance(res, list):
return f'{len(res)} items'
return f'{len(res_str)} chars\n {res_str[:1000] if len(res_str) > max_len else res_str}...'
print(f'{describe(tool_result, tool_result_str)}', file=sys.stderr)
if verbose:
print(tool_result_str, file=sys.stderr)
messages.append(dict(
tool_call_id=tool_call.get('id'),
role='tool',
content=tool_result_str,
))
else:
assert content
print(content)
return
while interactive: if max_iterations is not None:
await run_turn() raise Exception(f'Failed to get a valid response after {max_iterations} tool calls')
messages.append(dict(
role='user', while interactive:
content=input('💬 ') await run_turn()
)) messages.append(dict(
role='user',
content=input('💬 ')
))
except aiohttp.ClientResponseError as e:
sys.stdout.write(f'💥 {e}\n')
sys.exit(1)
if __name__ == '__main__': if __name__ == '__main__':