agent: support more providers (+ extract serve_tools_inside_docker.sh)

update readme

parent b4fc1e8ba7
commit da02397f7f

3 changed files with 64 additions and 25 deletions
examples/agent/README.md
@@ -39,6 +39,7 @@
 - Run the tools in [examples/agent/tools](./examples/agent/tools) inside a docker container (check http://localhost:8088/docs once running):

   ```bash
+  # Shorthand: ./examples/agent/serve_tools_inside_docker.sh
   docker run -p 8088:8088 -w /src -v $PWD/examples/agent:/src \
     --env BRAVE_SEARCH_API_KEY=$BRAVE_SEARCH_API_KEY \
     --rm -it ghcr.io/astral-sh/uv:python3.12-alpine \
@@ -99,13 +100,15 @@
 </details>


-- To compare the above results w/ OpenAI's tool usage behaviour, just add `--openai` to the agent invocation (other providers can easily be added, just use the `--endpoint`, `--api-key`, and `--model` flags)
+- To compare the above results w/ a cloud provider's tool usage behaviour, just set the `--provider` flag (accepts `openai`, `together`, `groq`) and/or use `--endpoint`, `--api-key`, and `--model`

   ```bash
-  export OPENAI_API_KEY=...
+  export OPENAI_API_KEY=... # for --provider=openai
+  # export TOGETHER_API_KEY=... # for --provider=together
+  # export GROQ_API_KEY=... # for --provider=groq
   uv run examples/agent/run.py --tools http://localhost:8088 \
     "Search for, fetch and summarize the homepage of llama.cpp" \
-    --openai
+    --provider=openai
   ```

 ## TODO
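For reference, the same comparison against one of the newly supported providers would look like the sketch below; the environment variable and model default come from the `_PROVIDERS` table added to run.py further down, and the goal string is just the README's example prompt:

```bash
# Sketch: run the same goal against Groq instead of OpenAI
# (assumes the tools server from above is still running on port 8088).
export GROQ_API_KEY=...
uv run examples/agent/run.py --tools http://localhost:8088 \
  "Search for, fetch and summarize the homepage of llama.cpp" \
  --provider=groq
```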
examples/agent/run.py
@@ -12,12 +12,11 @@ import aiohttp
 import asyncio
 from functools import wraps
 import json
-import logging
 import os
 from pydantic import BaseModel
 import sys
 import typer
-from typing import Optional
+from typing import Annotated, Literal, Optional
 import urllib.parse


 class OpenAPIMethod:
@@ -103,7 +102,7 @@ class OpenAPIMethod:

         return response_json

-async def discover_tools(tool_endpoints: list[str], logger) -> tuple[dict, list]:
+async def discover_tools(tool_endpoints: list[str], verbose) -> tuple[dict, list]:
     tool_map = {}
     tools = []

@@ -119,7 +118,8 @@ async def discover_tools(tool_endpoints: list[str], logger) -> tuple[dict, list]:
             for path, descriptor in catalog['paths'].items():
                 fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog)
                 tool_map[fn.__name__] = fn
-                logger.debug('Function %s: params schema: %s', fn.__name__, fn.parameters_schema)
+                if verbose:
+                    print(f'Function {fn.__name__}: params schema: {fn.parameters_schema}', file=sys.stderr)
                 tools.append(dict(
                     type='function',
                     function=dict(
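As the context lines above show, each OpenAPI path is turned into a tool name via `path.replace('/', ' ').strip().replace(' ', '_')`. A quick illustrative check of that mapping (the path `/search/web` is hypothetical, not necessarily one of the bundled tools):

```bash
# Illustrative only: reproduce the path -> tool-name mapping from OpenAPIMethod.
python3 -c "print('/search/web'.replace('/', ' ').strip().replace(' ', '_'))"
# prints: search_web
```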
@@ -142,6 +142,30 @@ def typer_async_workaround():
         return wrapper
     return decorator

+
+_PROVIDERS = {
+    'llama.cpp': {
+        'endpoint': 'http://localhost:8080/v1/',
+        'api_key_env': 'LLAMACPP_API_KEY',
+    },
+    'openai': {
+        'endpoint': 'https://api.openai.com/v1/',
+        'default_model': 'gpt-4o',
+        'api_key_env': 'OPENAI_API_KEY',
+    },
+    'together': {
+        'endpoint': 'https://api.together.xyz',
+        'default_model': 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo',
+        'api_key_env': 'TOGETHER_API_KEY',
+    },
+    'groq': {
+        'endpoint': 'https://api.groq.com/openai',
+        'default_model': 'llama-3.1-70b-versatile',
+        'api_key_env': 'GROQ_API_KEY',
+    },
+}
+
+
 @typer_async_workaround()
 async def main(
     goal: str,
@@ -152,23 +176,17 @@ async def main(
     cache_prompt: bool = True,
     seed: Optional[int] = None,
     interactive: bool = True,
-    openai: bool = False,
+    provider: Annotated[str, Literal['llama.cpp', 'openai', 'together', 'groq']] = 'llama.cpp',
     endpoint: Optional[str] = None,
     api_key: Optional[str] = None,
 ):
-    logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO, format='%(message)s')
-    logger = logging.getLogger(__name__)
+    provider_info = _PROVIDERS[provider]

     if endpoint is None:
-        if openai:
-            endpoint = 'https://api.openai.com/v1/'
-        else:
-            endpoint = 'http://localhost:8080/v1/'
+        endpoint = provider_info['endpoint']
     if api_key is None:
-        if openai:
-            api_key = os.environ.get('OPENAI_API_KEY')
+        api_key = os.environ.get(provider_info['api_key_env'])

-    tool_map, tools = await discover_tools(tools or [], logger=logger)
+    tool_map, tools = await discover_tools(tools or [], verbose)

     sys.stdout.write(f'🛠️ Tools: {", ".join(tool_map.keys()) if tool_map else "<none>"}\n')
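In effect, `--provider` now picks a default endpoint and API-key environment variable out of `_PROVIDERS`, while explicit `--endpoint` / `--api-key` still win because the lookups only run when the corresponding flag is left unset. A sketch of the resulting invocations (`"<goal>"` is a placeholder for the prompt):

```bash
# Default provider is llama.cpp: http://localhost:8080/v1/, $LLAMACPP_API_KEY if set.
uv run examples/agent/run.py --tools http://localhost:8088 "<goal>"

# Groq via the table defaults: https://api.groq.com/openai and $GROQ_API_KEY.
uv run examples/agent/run.py --tools http://localhost:8088 "<goal>" --provider=groq

# Or override endpoint/key/model explicitly (--provider then mainly gates the
# llama.cpp-specific payload fields shown in the next hunk):
uv run examples/agent/run.py --tools http://localhost:8088 "<goal>" \
  --endpoint=https://api.openai.com/v1/ --api-key=$OPENAI_API_KEY --model=gpt-4o
```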
@@ -191,16 +209,18 @@ async def main(
             model=model,
             tools=tools,
         )
-        if not openai:
+        if provider == 'llama.cpp':
             payload.update(dict(
                 seed=seed,
                 cache_prompt=cache_prompt,
             )) # type: ignore

-        logger.debug('Calling %s with %s', url, json.dumps(payload, indent=2))
+        if verbose:
+            print(f'Calling {url} with {json.dumps(payload, indent=2)}', file=sys.stderr)
         async with aiohttp.ClientSession(headers=headers) as session:
             async with session.post(url, json=payload) as response:
-                logger.debug('Response: %s', response)
+                if verbose:
+                    print(f'Response: {response}', file=sys.stderr)
                 response.raise_for_status()
                 response = await response.json()
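Request and response debugging now goes through plain stderr prints gated by `verbose` instead of the logging module. Assuming the `verbose` parameter of `main` (defined outside this hunk) is exposed by typer as a boolean flag, the debug output would be enabled like so:

```bash
# Hypothetical: print the request payload and raw response object to stderr.
uv run examples/agent/run.py --tools http://localhost:8088 "<goal>" --verbose
```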
@@ -213,17 +233,22 @@ async def main(
             assert choice['message']['tool_calls']
             for tool_call in choice['message']['tool_calls']:
                 if content:
-                    print(f'💭 {content}')
+                    print(f'💭 {content}', file=sys.stderr)

                 name = tool_call['function']['name']
                 args = json.loads(tool_call['function']['arguments'])
                 pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
-                logger.info(f'⚙️ {pretty_call}')
+                print(f'⚙️ {pretty_call}', file=sys.stderr, end=None)
                 sys.stdout.flush()
                 tool_result = await tool_map[name](**args)
                 tool_result_str = json.dumps(tool_result)
-                logger.info(' → %d chars', len(tool_result_str))
-                logger.debug('%s', tool_result_str)
+                def describe(res, res_str):
+                    if isinstance(res, list):
+                        return f'{len(res)} items'
+                    return f'{len(res_str)} chars'
+                print(f' → {describe(tool_result, tool_result_str)}', file=sys.stderr)
+                if verbose:
+                    print(tool_result_str, file=sys.stderr)
                 messages.append(dict(
                     tool_call_id=tool_call.get('id'),
                     role='tool',
examples/agent/serve_tools_inside_docker.sh (new executable file, 11 lines)
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -euo pipefail
+
+PORT=${PORT:-8088}
+
+docker run -p $PORT:$PORT \
+    -w /src \
+    -v $PWD/examples/agent:/src \
+    --env BRAVE_SEARCH_API_KEY=$BRAVE_SEARCH_API_KEY \
+    --rm -it ghcr.io/astral-sh/uv:python3.12-alpine \
+    uv run serve_tools.py --port $PORT
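A usage sketch for the new script. The port override works because of the `PORT=${PORT:-8088}` default above; note that `set -u` means `BRAVE_SEARCH_API_KEY` must be exported (even if empty) or the script aborts with an unbound-variable error:

```bash
# Serve the tools on the default port 8088:
export BRAVE_SEARCH_API_KEY=...   # required by set -u; used by the search tool
./examples/agent/serve_tools_inside_docker.sh

# Or pick another port, then point the agent at it:
PORT=9000 ./examples/agent/serve_tools_inside_docker.sh
# uv run examples/agent/run.py --tools http://localhost:9000 "<goal>"
```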