agent
: --openai flag (auto-fetches OPENAI_API_KEY), improved logging
This commit is contained in:
parent
2428b73853
commit
e2a9ab68a3
2 changed files with 56 additions and 25 deletions
|
@ -48,7 +48,7 @@
|
||||||
> [!WARNING]
|
> [!WARNING]
|
||||||
> The command above gives tools (and your agent) access to the web (and read-only access to `examples/agent/**`. If you're concerned about unleashing a rogue agent on the web, please explore setting up proxies for your docker (and contribute back!)
|
> The command above gives tools (and your agent) access to the web (and read-only access to `examples/agent/**`. If you're concerned about unleashing a rogue agent on the web, please explore setting up proxies for your docker (and contribute back!)
|
||||||
|
|
||||||
- Run the agent with a given goal:
|
- Run the agent with a given goal
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run examples/agent/run.py --tools http://localhost:8088 \
|
uv run examples/agent/run.py --tools http://localhost:8088 \
|
||||||
|
@ -61,6 +61,15 @@
|
||||||
"Search for, fetch and summarize the homepage of llama.cpp"
|
"Search for, fetch and summarize the homepage of llama.cpp"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- To compare the above results w/ OpenAI's tool usage behaviour, just add `--openai` to the agent invocation (other providers can easily be added, just use the `--endpoint`, `--api-key`, and `--model` flags)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export OPENAI_API_KEY=...
|
||||||
|
uv run examples/agent/run.py --tools http://localhost:8088 \
|
||||||
|
"Search for, fetch and summarize the homepage of llama.cpp" \
|
||||||
|
--openai
|
||||||
|
```
|
||||||
|
|
||||||
## TODO
|
## TODO
|
||||||
|
|
||||||
- Implement code_interpreter using whichever tools are builtin for a given model.
|
- Implement code_interpreter using whichever tools are builtin for a given model.
|
||||||
|
|
|
@ -10,6 +10,8 @@
|
||||||
# ///
|
# ///
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
@ -71,7 +73,7 @@ class OpenAPIMethod:
|
||||||
if self.body:
|
if self.body:
|
||||||
body = kwargs.pop(self.body['name'], None)
|
body = kwargs.pop(self.body['name'], None)
|
||||||
if self.body['required']:
|
if self.body['required']:
|
||||||
assert body is not None, f'Missing required body parameter: {self.body["name"]}'
|
assert body is not None, f'Missing required body parameter: {self.body['name']}'
|
||||||
else:
|
else:
|
||||||
body = None
|
body = None
|
||||||
|
|
||||||
|
@ -84,7 +86,7 @@ class OpenAPIMethod:
|
||||||
assert param['in'] == 'query', 'Only query parameters are supported'
|
assert param['in'] == 'query', 'Only query parameters are supported'
|
||||||
query_params[name] = value
|
query_params[name] = value
|
||||||
|
|
||||||
params = "&".join(f"{name}={urllib.parse.quote(str(value))}" for name, value in query_params.items() if value is not None)
|
params = '&'.join(f'{name}={urllib.parse.quote(str(value))}' for name, value in query_params.items() if value is not None)
|
||||||
url = f'{self.url}?{params}'
|
url = f'{self.url}?{params}'
|
||||||
async with session.post(url, json=body) as response:
|
async with session.post(url, json=body) as response:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
@ -92,7 +94,7 @@ class OpenAPIMethod:
|
||||||
|
|
||||||
return response_json
|
return response_json
|
||||||
|
|
||||||
async def discover_tools(tool_endpoints: list[str], verbose: bool = False) -> tuple[dict, list]:
|
async def discover_tools(tool_endpoints: list[str], logger) -> tuple[dict, list]:
|
||||||
tool_map = {}
|
tool_map = {}
|
||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
|
@ -108,10 +110,9 @@ async def discover_tools(tool_endpoints: list[str], verbose: bool = False) -> tu
|
||||||
for path, descriptor in catalog['paths'].items():
|
for path, descriptor in catalog['paths'].items():
|
||||||
fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog)
|
fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog)
|
||||||
tool_map[fn.__name__] = fn
|
tool_map[fn.__name__] = fn
|
||||||
if verbose:
|
logger.debug('Function %s: params schema: %s', fn.__name__, fn.parameters_schema)
|
||||||
sys.stderr.write(f'# PARAMS SCHEMA ({fn.__name__}): {json.dumps(fn.parameters_schema, indent=2)}\n')
|
|
||||||
tools.append(dict(
|
tools.append(dict(
|
||||||
type="function",
|
type='function',
|
||||||
function=dict(
|
function=dict(
|
||||||
name=fn.__name__,
|
name=fn.__name__,
|
||||||
description=fn.__doc__ or '',
|
description=fn.__doc__ or '',
|
||||||
|
@ -134,26 +135,41 @@ def typer_async_workaround():
|
||||||
@typer_async_workaround()
|
@typer_async_workaround()
|
||||||
async def main(
|
async def main(
|
||||||
goal: str,
|
goal: str,
|
||||||
api_key: str = '<unset>',
|
model: str = 'gpt-4o',
|
||||||
tools: Optional[list[str]] = None,
|
tools: Optional[list[str]] = None,
|
||||||
max_iterations: Optional[int] = 10,
|
max_iterations: Optional[int] = 10,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
cache_prompt: bool = True,
|
cache_prompt: bool = True,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
endpoint: str = "http://localhost:8080/v1/",
|
openai: bool = False,
|
||||||
|
endpoint: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
):
|
):
|
||||||
tool_map, tools = await discover_tools(tools or [], verbose)
|
logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO, format='%(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')
|
if endpoint is None:
|
||||||
|
if openai:
|
||||||
|
endpoint = 'https://api.openai.com/v1/'
|
||||||
|
else:
|
||||||
|
endpoint = 'http://localhost:8080/v1/'
|
||||||
|
if api_key is None:
|
||||||
|
if openai:
|
||||||
|
api_key = os.environ.get('OPENAI_API_KEY')
|
||||||
|
|
||||||
|
tool_map, tools = await discover_tools(tools or [], logger=logger)
|
||||||
|
|
||||||
|
sys.stdout.write(f'🛠️ Tools: {", ".join(tool_map.keys()) if tool_map else "<none>"}\n')
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
dict(
|
dict(
|
||||||
role="user",
|
role='user',
|
||||||
content=goal,
|
content=goal,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
'Authorization': f'Bearer {api_key}'
|
'Authorization': f'Bearer {api_key}'
|
||||||
}
|
}
|
||||||
async with aiohttp.ClientSession(headers=headers) as session:
|
async with aiohttp.ClientSession(headers=headers) as session:
|
||||||
|
@ -161,22 +177,26 @@ async def main(
|
||||||
url = f'{endpoint}chat/completions'
|
url = f'{endpoint}chat/completions'
|
||||||
payload = dict(
|
payload = dict(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model="gpt-4o",
|
model=model,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
seed=seed,
|
|
||||||
cache_prompt=cache_prompt,
|
|
||||||
)
|
)
|
||||||
|
if not openai:
|
||||||
|
payload.update(dict(
|
||||||
|
seed=seed,
|
||||||
|
cache_prompt=cache_prompt,
|
||||||
|
)) # type: ignore
|
||||||
|
|
||||||
|
logger.debug('Calling %s with %s', url, json.dumps(payload, indent=2))
|
||||||
async with session.post(url, json=payload) as response:
|
async with session.post(url, json=payload) as response:
|
||||||
if verbose:
|
logger.debug('Response: %s', response)
|
||||||
sys.stderr.write(f'# RESPONSE: {response}\n')
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
response = await response.json()
|
response = await response.json()
|
||||||
|
|
||||||
assert len(response["choices"]) == 1
|
assert len(response['choices']) == 1
|
||||||
choice = response["choices"][0]
|
choice = response['choices'][0]
|
||||||
|
|
||||||
content = choice['message']['content']
|
content = choice['message']['content']
|
||||||
if choice['finish_reason'] == "tool_calls":
|
if choice['finish_reason'] == 'tool_calls':
|
||||||
messages.append(choice['message'])
|
messages.append(choice['message'])
|
||||||
assert choice['message']['tool_calls']
|
assert choice['message']['tool_calls']
|
||||||
for tool_call in choice['message']['tool_calls']:
|
for tool_call in choice['message']['tool_calls']:
|
||||||
|
@ -186,14 +206,16 @@ async def main(
|
||||||
name = tool_call['function']['name']
|
name = tool_call['function']['name']
|
||||||
args = json.loads(tool_call['function']['arguments'])
|
args = json.loads(tool_call['function']['arguments'])
|
||||||
pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
|
pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
|
||||||
sys.stdout.write(f'⚙️ {pretty_call}')
|
logger.info(f'⚙️ {pretty_call}')
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
tool_result = await tool_map[name](session, **args)
|
tool_result = await tool_map[name](session, **args)
|
||||||
sys.stdout.write(f" → {tool_result}\n")
|
tool_result_str = json.dumps(tool_result)
|
||||||
|
logger.info(' → %d chars', len(tool_result_str))
|
||||||
|
logger.debug('%s', tool_result_str)
|
||||||
messages.append(dict(
|
messages.append(dict(
|
||||||
tool_call_id=tool_call.get('id'),
|
tool_call_id=tool_call.get('id'),
|
||||||
role="tool",
|
role='tool',
|
||||||
content=json.dumps(tool_result),
|
content=tool_result_str,
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
assert content
|
assert content
|
||||||
|
@ -201,7 +223,7 @@ async def main(
|
||||||
return
|
return
|
||||||
|
|
||||||
if max_iterations is not None:
|
if max_iterations is not None:
|
||||||
raise Exception(f"Failed to get a valid response after {max_iterations} tool calls")
|
raise Exception(f'Failed to get a valid response after {max_iterations} tool calls')
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
typer.run(main)
|
typer.run(main)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue