agent
: display http errors nicely
This commit is contained in:
parent
f5320af02a
commit
0f5d63943f
1 changed files with 89 additions and 82 deletions
|
@ -14,10 +14,10 @@ from functools import wraps
|
||||||
import json
|
import json
|
||||||
from openapi import discover_tools
|
from openapi import discover_tools
|
||||||
import os
|
import os
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field, Json
|
||||||
import sys
|
import sys
|
||||||
import typer
|
import typer
|
||||||
from typing import Annotated, Literal, Optional
|
from typing import Annotated, Dict, Literal, Optional
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
|
||||||
|
|
||||||
|
@ -80,94 +80,101 @@ async def main(
|
||||||
tool_map, tools = await discover_tools(tools or [], verbose)
|
tool_map, tools = await discover_tools(tools or [], verbose)
|
||||||
|
|
||||||
sys.stdout.write(f'🛠️ Tools: {", ".join(tool_map.keys()) if tool_map else "<none>"}\n')
|
sys.stdout.write(f'🛠️ Tools: {", ".join(tool_map.keys()) if tool_map else "<none>"}\n')
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
if system:
|
if system:
|
||||||
messages.append(dict(
|
messages.append(dict(
|
||||||
role='system',
|
role='system',
|
||||||
content=system,
|
content=system,
|
||||||
))
|
))
|
||||||
messages.append(
|
messages.append(
|
||||||
dict(
|
dict(
|
||||||
role='user',
|
role='user',
|
||||||
content=goal,
|
content=goal,
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Authorization': f'Bearer {api_key}'
|
|
||||||
}
|
|
||||||
async def run_turn():
|
|
||||||
for i in range(max_iterations or sys.maxsize):
|
|
||||||
url = f'{endpoint}chat/completions'
|
|
||||||
payload = dict(
|
|
||||||
messages=messages,
|
|
||||||
model=model,
|
|
||||||
tools=tools,
|
|
||||||
)
|
)
|
||||||
if provider == 'llama.cpp':
|
)
|
||||||
payload.update(dict(
|
|
||||||
seed=seed,
|
|
||||||
cache_prompt=cache_prompt,
|
|
||||||
)) # type: ignore
|
|
||||||
|
|
||||||
if verbose:
|
headers = {
|
||||||
print(f'Calling {url} with {json.dumps(payload, indent=2)}', file=sys.stderr)
|
'Content-Type': 'application/json',
|
||||||
async with aiohttp.ClientSession(headers=headers) as session:
|
'Authorization': f'Bearer {api_key}'
|
||||||
async with session.post(url, json=payload) as response:
|
}
|
||||||
response.raise_for_status()
|
async def run_turn():
|
||||||
response = await response.json()
|
for i in range(max_iterations or sys.maxsize):
|
||||||
if verbose:
|
url = f'{endpoint}chat/completions'
|
||||||
print(f'Response: {json.dumps(response, indent=2)}', file=sys.stderr)
|
payload = dict(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
if provider == 'llama.cpp':
|
||||||
|
payload.update(dict(
|
||||||
|
seed=seed,
|
||||||
|
cache_prompt=cache_prompt,
|
||||||
|
)) # type: ignore
|
||||||
|
|
||||||
assert len(response['choices']) == 1
|
if verbose:
|
||||||
choice = response['choices'][0]
|
print(f'Calling {url} with {json.dumps(payload, indent=2)}', file=sys.stderr)
|
||||||
|
async with aiohttp.ClientSession(headers=headers) as session:
|
||||||
|
async with session.post(url, json=payload) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
response = await response.json()
|
||||||
|
if verbose:
|
||||||
|
print(f'Response: {json.dumps(response, indent=2)}', file=sys.stderr)
|
||||||
|
|
||||||
content = choice['message']['content']
|
assert len(response['choices']) == 1
|
||||||
if choice['finish_reason'] == 'tool_calls':
|
choice = response['choices'][0]
|
||||||
messages.append(choice['message'])
|
|
||||||
assert choice['message']['tool_calls']
|
|
||||||
for tool_call in choice['message']['tool_calls']:
|
|
||||||
if content:
|
|
||||||
print(f'💭 {content}', file=sys.stderr)
|
|
||||||
|
|
||||||
name = tool_call['function']['name']
|
content = choice['message']['content']
|
||||||
args = json.loads(tool_call['function']['arguments'])
|
if choice['finish_reason'] == 'tool_calls':
|
||||||
pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
|
messages.append(choice['message'])
|
||||||
print(f'⚙️ {pretty_call}', file=sys.stderr, end=None)
|
assert choice['message']['tool_calls']
|
||||||
sys.stdout.flush()
|
for tool_call in choice['message']['tool_calls']:
|
||||||
try:
|
if content:
|
||||||
tool_result = await tool_map[name](**args)
|
print(f'💭 {content}', file=sys.stderr)
|
||||||
except Exception as e:
|
|
||||||
tool_result = 'ERROR: ' + str(e)
|
|
||||||
tool_result_str = tool_result if isinstance(tool_result, str) else json.dumps(tool_result)
|
|
||||||
def describe(res, res_str, max_len = 1000):
|
|
||||||
if isinstance(res, list):
|
|
||||||
return f'{len(res)} items'
|
|
||||||
return f'{len(res_str)} chars\n {res_str[:1000] if len(res_str) > max_len else res_str}...'
|
|
||||||
print(f' → {describe(tool_result, tool_result_str)}', file=sys.stderr)
|
|
||||||
if verbose:
|
|
||||||
print(tool_result_str, file=sys.stderr)
|
|
||||||
messages.append(dict(
|
|
||||||
tool_call_id=tool_call.get('id'),
|
|
||||||
role='tool',
|
|
||||||
content=tool_result_str,
|
|
||||||
))
|
|
||||||
else:
|
|
||||||
assert content
|
|
||||||
print(content)
|
|
||||||
return
|
|
||||||
|
|
||||||
if max_iterations is not None:
|
name = tool_call['function']['name']
|
||||||
raise Exception(f'Failed to get a valid response after {max_iterations} tool calls')
|
args = json.loads(tool_call['function']['arguments'])
|
||||||
|
print(f'tool_call: {json.dumps(tool_call, indent=2)}', file=sys.stderr)
|
||||||
|
pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
|
||||||
|
print(f'⚙️ {pretty_call}', file=sys.stderr, end=None)
|
||||||
|
sys.stdout.flush()
|
||||||
|
try:
|
||||||
|
tool_result = await tool_map[name](**args)
|
||||||
|
except Exception as e:
|
||||||
|
tool_result = 'ERROR: ' + str(e)
|
||||||
|
tool_result_str = tool_result if isinstance(tool_result, str) else json.dumps(tool_result)
|
||||||
|
def describe(res, res_str, max_len = 1000):
|
||||||
|
if isinstance(res, list):
|
||||||
|
return f'{len(res)} items'
|
||||||
|
return f'{len(res_str)} chars\n {res_str[:1000] if len(res_str) > max_len else res_str}...'
|
||||||
|
print(f' → {describe(tool_result, tool_result_str)}', file=sys.stderr)
|
||||||
|
if verbose:
|
||||||
|
print(tool_result_str, file=sys.stderr)
|
||||||
|
messages.append(dict(
|
||||||
|
tool_call_id=tool_call.get('id'),
|
||||||
|
role='tool',
|
||||||
|
content=tool_result_str,
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
assert content
|
||||||
|
print(content)
|
||||||
|
return
|
||||||
|
|
||||||
while interactive:
|
if max_iterations is not None:
|
||||||
await run_turn()
|
raise Exception(f'Failed to get a valid response after {max_iterations} tool calls')
|
||||||
messages.append(dict(
|
|
||||||
role='user',
|
while interactive:
|
||||||
content=input('💬 ')
|
await run_turn()
|
||||||
))
|
messages.append(dict(
|
||||||
|
role='user',
|
||||||
|
content=input('💬 ')
|
||||||
|
))
|
||||||
|
|
||||||
|
except aiohttp.ClientResponseError as e:
|
||||||
|
sys.stdout.write(f'💥 {e}\n')
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue