agent: add --think "tool", default to local tools endpoint, support --temperature, fix --seed

This commit is contained in:
ochafik 2024-12-05 21:32:08 +00:00
parent f9b1969097
commit adc673c355
3 changed files with 46 additions and 21 deletions

View file

@ -1,14 +1,15 @@
FROM python:3.12-slim
RUN python -m pip install --upgrade pip && \
apt update && \
apt install -y wget && \
apt clean cache
COPY requirements.txt /root/
COPY tools /root/tools
WORKDIR /root
RUN pip install docling --extra-index-url https://download.pytorch.org/whl/cpu && \
pip install -r requirements.txt
COPY tools /root/tools
COPY ./squid/ssl_cert/squidCA.crt /usr/local/share/ca-certificates/squidCA.crt
RUN chmod 644 /usr/local/share/ca-certificates/squidCA.crt && update-ca-certificates

View file

@ -14,13 +14,10 @@ from functools import wraps
import json
from openapi import discover_tools
import os
from pydantic import BaseModel, Field, Json
from pydantic import BaseModel
import sys
import typer
from typing import Annotated, Dict, Literal, Optional
import urllib.parse
from typing import Annotated, Literal, Optional
def typer_async_workaround():
@ -60,19 +57,21 @@ _PROVIDERS = {
async def main(
goal: str,
model: str = 'gpt-4o',
tools: Optional[list[str]] = None,
tool_endpoints: Optional[list[str]] = None,
think: bool = False,
max_iterations: Optional[int] = 10,
system: Optional[str] = None,
verbose: bool = False,
cache_prompt: bool = True,
temperature: Optional[int] = None,
seed: Optional[int] = None,
interactive: bool = True,
provider: Annotated[str, Literal['llama.cpp', 'openai', 'together', 'groq']] = 'llama.cpp',
endpoint: Optional[str] = None,
api_key: Optional[str] = None,
):
if not tools:
tools = ["http://localhost:8088"]
if not tool_endpoints:
tool_endpoints = ["http://localhost:8088"]
provider_info = _PROVIDERS[provider]
if endpoint is None:
@ -80,7 +79,26 @@ async def main(
if api_key is None:
api_key = os.environ.get(provider_info['api_key_env'])
tool_map, tools = await discover_tools(tools or [], verbose)
tool_map, tools = await discover_tools(tool_endpoints or [], verbose)
if think:
tools.append({
'type': 'function',
'function': {
'name': 'think',
'description': 'Call this function at every step to explain your thought process, before taking any other action',
'parameters': {
'type': 'object',
'properties': {
'thought': {
'type': 'string'
}
},
'required': ['thought']
}
}
})
tool_map['think'] = lambda thought: 'ACK'
sys.stdout.write(f'🛠️ Tools: {", ".join(tool_map.keys()) if tool_map else "<none>"}\n')
@ -110,10 +128,11 @@ async def main(
messages=messages,
model=model,
tools=tools,
temperature=temperature,
seed=seed,
)
if provider == 'llama.cpp':
payload.update(dict(
seed=seed,
cache_prompt=cache_prompt,
)) # type: ignore
@ -139,20 +158,25 @@ async def main(
name = tool_call['function']['name']
args = json.loads(tool_call['function']['arguments'])
print(f'tool_call: {json.dumps(tool_call, indent=2)}', file=sys.stderr)
pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
print(f'⚙️ {pretty_call}', file=sys.stderr, end=None)
sys.stdout.flush()
if verbose:
print(f'tool_call: {json.dumps(tool_call, indent=2)}', file=sys.stderr)
if think and name == 'think':
print(f'🧠 {args["thought"]}', file=sys.stderr)
else:
pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
print(f'⚙️ {pretty_call}', file=sys.stderr, end=None)
sys.stderr.flush()
try:
tool_result = await tool_map[name](**args)
except Exception as e:
tool_result = 'ERROR: ' + str(e)
tool_result_str = tool_result if isinstance(tool_result, str) else json.dumps(tool_result)
def describe(res, res_str, max_len = 1000):
if isinstance(res, list):
return f'{len(res)} items'
return f'{len(res_str)} chars\n {res_str[:1000] if len(res_str) > max_len else res_str}...'
print(f'{describe(tool_result, tool_result_str)}', file=sys.stderr)
if not (think and name == 'think'):
def describe(res, res_str, max_len = 1000):
if isinstance(res, list):
return f'{len(res)} items'
return f'{len(res_str)} chars\n {res_str[:1000] if len(res_str) > max_len else res_str}...'
print(f'{describe(tool_result, tool_result_str)}', file=sys.stderr)
if verbose:
print(tool_result_str, file=sys.stderr)
messages.append(dict(

View file

@ -27,4 +27,4 @@ openssl req -new -newkey rsa:4096 -days 3650 -nodes -x509 \
openssl x509 -outform PEM -in squid/ssl_cert/squidCA.pem -out squid/ssl_cert/squidCA.crt
docker compose --verbose up --build "$@"
docker compose up --build "$@"