From da02397f7fd5444df3f24a96aa1b2fdf52f05d43 Mon Sep 17 00:00:00 2001
From: Olivier Chafik
Date: Thu, 3 Oct 2024 19:18:47 +0100
Subject: [PATCH] `agent`: support more providers (+ extract
 serve_tools_inside_docker.sh)

update readme
---
 examples/agent/README.md                    |  9 ++-
 examples/agent/run.py                       | 73 ++++++++++++++---------
 examples/agent/serve_tools_inside_docker.sh | 17 ++++++
 3 files changed, 73 insertions(+), 26 deletions(-)
 create mode 100755 examples/agent/serve_tools_inside_docker.sh

diff --git a/examples/agent/README.md b/examples/agent/README.md
index d42fa5e36..575fdeaff 100644
--- a/examples/agent/README.md
+++ b/examples/agent/README.md
@@ -39,6 +39,7 @@
 - Run the tools in [examples/agent/tools](./examples/agent/tools) inside a docker container (check http://localhost:8088/docs once running):
 
     ```bash
+    # Shorthand: ./examples/agent/serve_tools_inside_docker.sh
     docker run -p 8088:8088 -w /src -v $PWD/examples/agent:/src \
         --env BRAVE_SEARCH_API_KEY=$BRAVE_SEARCH_API_KEY \
         --rm -it ghcr.io/astral-sh/uv:python3.12-alpine \
@@ -99,13 +100,15 @@
 
-- To compare the above results w/ OpenAI's tool usage behaviour, just add `--openai` to the agent invocation (other providers can easily be added, just use the `--endpoint`, `--api-key`, and `--model` flags)
+- To compare the above results w/ a cloud provider's tool usage behaviour, set the `--provider` flag (one of `openai`, `together` or `groq`) and/or override `--endpoint`, `--api-key` and `--model`:
 
     ```bash
-    export OPENAI_API_KEY=...
+    export OPENAI_API_KEY=...     # for --provider=openai
+    # export TOGETHER_API_KEY=... # for --provider=together
+    # export GROQ_API_KEY=...     # for --provider=groq
     uv run examples/agent/run.py --tools http://localhost:8088 \
         "Search for, fetch and summarize the homepage of llama.cpp" \
-        --openai
+        --provider=openai
     ```
 
 ## TODO
diff --git a/examples/agent/run.py b/examples/agent/run.py
index b38b183db..796d40996 100644
--- a/examples/agent/run.py
+++ b/examples/agent/run.py
@@ -12,12 +12,11 @@
 import aiohttp
 import asyncio
 from functools import wraps
 import json
-import logging
 import os
 from pydantic import BaseModel
 import sys
 import typer
-from typing import Optional
+from typing import Annotated, Literal, Optional
 import urllib.parse
 
@@ -103,7 +102,7 @@ class OpenAPIMethod:
 
         return response_json
 
-async def discover_tools(tool_endpoints: list[str], logger) -> tuple[dict, list]:
+async def discover_tools(tool_endpoints: list[str], verbose: bool) -> tuple[dict, list]:
     tool_map = {}
     tools = []
 
@@ -119,7 +118,8 @@
         for path, descriptor in catalog['paths'].items():
             fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog)
             tool_map[fn.__name__] = fn
-            logger.debug('Function %s: params schema: %s', fn.__name__, fn.parameters_schema)
+            if verbose:
+                print(f'Function {fn.__name__}: params schema: {fn.parameters_schema}', file=sys.stderr)
             tools.append(dict(
                 type='function',
                 function=dict(
@@ -142,6 +142,32 @@ def typer_async_workaround():
         return wrapper
     return decorator
 
+
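+# Each provider entry gives the base URL of an OpenAI-compatible chat API, the
+# env var read when --api-key is unset, and (optionally) a default model.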
+_PROVIDERS = {
+    'llama.cpp': {
+        'endpoint': 'http://localhost:8080/v1/',
+        'api_key_env': 'LLAMACPP_API_KEY',
+    },
+    'openai': {
+        'endpoint': 'https://api.openai.com/v1/',
+        'default_model': 'gpt-4o',
+        'api_key_env': 'OPENAI_API_KEY',
+    },
+    'together': {
+        'endpoint': 'https://api.together.xyz/v1/',
+        'default_model': 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo',
+        'api_key_env': 'TOGETHER_API_KEY',
+    },
+    'groq': {
+        'endpoint': 'https://api.groq.com/openai/v1/',
+        'default_model': 'llama-3.1-70b-versatile',
+        'api_key_env': 'GROQ_API_KEY',
+    },
+}
+
+
 @typer_async_workaround()
 async def main(
     goal: str,
@@ -152,23 +178,17 @@
     cache_prompt: bool = True,
     seed: Optional[int] = None,
     interactive: bool = True,
-    openai: bool = False,
+    provider: Annotated[str, Literal['llama.cpp', 'openai', 'together', 'groq']] = 'llama.cpp',
     endpoint: Optional[str] = None,
     api_key: Optional[str] = None,
 ):
-    logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO, format='%(message)s')
-    logger = logging.getLogger(__name__)
-
+    provider_info = _PROVIDERS[provider]
     if endpoint is None:
-        if openai:
-            endpoint = 'https://api.openai.com/v1/'
-        else:
-            endpoint = 'http://localhost:8080/v1/'
+        endpoint = provider_info['endpoint']
     if api_key is None:
-        if openai:
-            api_key = os.environ.get('OPENAI_API_KEY')
+        api_key = os.environ.get(provider_info['api_key_env'])
 
-    tool_map, tools = await discover_tools(tools or [], logger=logger)
+    tool_map, tools = await discover_tools(tools or [], verbose)
 
     sys.stdout.write(f'🛠️ Tools: {", ".join(tool_map.keys()) if tool_map else ""}\n')
 
@@ -191,16 +211,18 @@
             model=model,
             tools=tools,
         )
-        if not openai:
+        if provider == 'llama.cpp':
             payload.update(dict(
                 seed=seed,
                 cache_prompt=cache_prompt,
             )) # type: ignore
 
-        logger.debug('Calling %s with %s', url, json.dumps(payload, indent=2))
+        if verbose:
+            print(f'Calling {url} with {json.dumps(payload, indent=2)}', file=sys.stderr)
         async with aiohttp.ClientSession(headers=headers) as session:
             async with session.post(url, json=payload) as response:
-                logger.debug('Response: %s', response)
+                if verbose:
+                    print(f'Response: {response}', file=sys.stderr)
                 response.raise_for_status()
                 response = await response.json()
 
@@ -213,17 +235,22 @@
             assert choice['message']['tool_calls']
             for tool_call in choice['message']['tool_calls']:
                 if content:
-                    print(f'💭 {content}')
+                    print(f'💭 {content}', file=sys.stderr)
 
                 name = tool_call['function']['name']
                 args = json.loads(tool_call['function']['arguments'])
                 pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
-                logger.info(f'⚙️ {pretty_call}')
-                sys.stdout.flush()
+                print(f'⚙️ {pretty_call}', file=sys.stderr, end='')
+                sys.stderr.flush()
                 tool_result = await tool_map[name](**args)
                 tool_result_str = json.dumps(tool_result)
-                logger.info(' → %d chars', len(tool_result_str))
-                logger.debug('%s', tool_result_str)
+                def describe(res, res_str):
+                    if isinstance(res, list):
+                        return f'{len(res)} items'
+                    return f'{len(res_str)} chars'
+                print(f' → {describe(tool_result, tool_result_str)}', file=sys.stderr)
+                if verbose:
+                    print(tool_result_str, file=sys.stderr)
                 messages.append(dict(
                     tool_call_id=tool_call.get('id'),
                     role='tool',
diff --git a/examples/agent/serve_tools_inside_docker.sh b/examples/agent/serve_tools_inside_docker.sh
new file mode 100755
index 000000000..550587d82
--- /dev/null
+++ b/examples/agent/serve_tools_inside_docker.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+set -euo pipefail
+
+PORT=${PORT:-8088}
+
+docker run -p "$PORT:$PORT" \
+    -w /src \
+    -v "$PWD/examples/agent":/src \
+    --env "BRAVE_SEARCH_API_KEY=$BRAVE_SEARCH_API_KEY" \
+    --rm -it ghcr.io/astral-sh/uv:python3.12-alpine \
+    uv run serve_tools.py --port "$PORT"
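+
+# Notes:
+# - Once running, the tools' OpenAPI docs are browsable at
+#   http://localhost:$PORT/docs
+# - `set -u` aborts early if BRAVE_SEARCH_API_KEY is not exported; it is
+#   passed through to the container for the Brave web-search tool.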