agent: add --think "tool", default to local tools endpoint, support --temperature, fix --seed
This commit is contained in:
parent
f9b1969097
commit
adc673c355
3 changed files with 46 additions and 21 deletions
|
@ -1,14 +1,15 @@
|
|||
FROM python:3.12-slim
|
||||
|
||||
RUN python -m pip install --upgrade pip && \
|
||||
apt update && \
|
||||
apt install -y wget && \
|
||||
apt clean cache
|
||||
|
||||
COPY requirements.txt /root/
|
||||
COPY tools /root/tools
|
||||
WORKDIR /root
|
||||
RUN pip install docling --extra-index-url https://download.pytorch.org/whl/cpu && \
|
||||
pip install -r requirements.txt
|
||||
COPY tools /root/tools
|
||||
|
||||
COPY ./squid/ssl_cert/squidCA.crt /usr/local/share/ca-certificates/squidCA.crt
|
||||
RUN chmod 644 /usr/local/share/ca-certificates/squidCA.crt && update-ca-certificates
|
||||
|
|
|
@ -14,13 +14,10 @@ from functools import wraps
|
|||
import json
|
||||
from openapi import discover_tools
|
||||
import os
|
||||
from pydantic import BaseModel, Field, Json
|
||||
from pydantic import BaseModel
|
||||
import sys
|
||||
import typer
|
||||
from typing import Annotated, Dict, Literal, Optional
|
||||
import urllib.parse
|
||||
|
||||
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
|
||||
def typer_async_workaround():
|
||||
|
@ -60,19 +57,21 @@ _PROVIDERS = {
|
|||
async def main(
|
||||
goal: str,
|
||||
model: str = 'gpt-4o',
|
||||
tools: Optional[list[str]] = None,
|
||||
tool_endpoints: Optional[list[str]] = None,
|
||||
think: bool = False,
|
||||
max_iterations: Optional[int] = 10,
|
||||
system: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
cache_prompt: bool = True,
|
||||
temperature: Optional[int] = None,
|
||||
seed: Optional[int] = None,
|
||||
interactive: bool = True,
|
||||
provider: Annotated[str, Literal['llama.cpp', 'openai', 'together', 'groq']] = 'llama.cpp',
|
||||
endpoint: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
if not tools:
|
||||
tools = ["http://localhost:8088"]
|
||||
if not tool_endpoints:
|
||||
tool_endpoints = ["http://localhost:8088"]
|
||||
|
||||
provider_info = _PROVIDERS[provider]
|
||||
if endpoint is None:
|
||||
|
@ -80,7 +79,26 @@ async def main(
|
|||
if api_key is None:
|
||||
api_key = os.environ.get(provider_info['api_key_env'])
|
||||
|
||||
tool_map, tools = await discover_tools(tools or [], verbose)
|
||||
tool_map, tools = await discover_tools(tool_endpoints or [], verbose)
|
||||
|
||||
if think:
|
||||
tools.append({
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'think',
|
||||
'description': 'Call this function at every step to explain your thought process, before taking any other action',
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'thought': {
|
||||
'type': 'string'
|
||||
}
|
||||
},
|
||||
'required': ['thought']
|
||||
}
|
||||
}
|
||||
})
|
||||
tool_map['think'] = lambda thought: 'ACK'
|
||||
|
||||
sys.stdout.write(f'🛠️ Tools: {", ".join(tool_map.keys()) if tool_map else "<none>"}\n')
|
||||
|
||||
|
@ -110,10 +128,11 @@ async def main(
|
|||
messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
temperature=temperature,
|
||||
seed=seed,
|
||||
)
|
||||
if provider == 'llama.cpp':
|
||||
payload.update(dict(
|
||||
seed=seed,
|
||||
cache_prompt=cache_prompt,
|
||||
)) # type: ignore
|
||||
|
||||
|
@ -139,20 +158,25 @@ async def main(
|
|||
|
||||
name = tool_call['function']['name']
|
||||
args = json.loads(tool_call['function']['arguments'])
|
||||
print(f'tool_call: {json.dumps(tool_call, indent=2)}', file=sys.stderr)
|
||||
pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
|
||||
print(f'⚙️ {pretty_call}', file=sys.stderr, end=None)
|
||||
sys.stdout.flush()
|
||||
if verbose:
|
||||
print(f'tool_call: {json.dumps(tool_call, indent=2)}', file=sys.stderr)
|
||||
if think and name == 'think':
|
||||
print(f'🧠 {args["thought"]}', file=sys.stderr)
|
||||
else:
|
||||
pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
|
||||
print(f'⚙️ {pretty_call}', file=sys.stderr, end=None)
|
||||
sys.stderr.flush()
|
||||
try:
|
||||
tool_result = await tool_map[name](**args)
|
||||
except Exception as e:
|
||||
tool_result = 'ERROR: ' + str(e)
|
||||
tool_result_str = tool_result if isinstance(tool_result, str) else json.dumps(tool_result)
|
||||
def describe(res, res_str, max_len = 1000):
|
||||
if isinstance(res, list):
|
||||
return f'{len(res)} items'
|
||||
return f'{len(res_str)} chars\n {res_str[:1000] if len(res_str) > max_len else res_str}...'
|
||||
print(f' → {describe(tool_result, tool_result_str)}', file=sys.stderr)
|
||||
if not (think and name == 'think'):
|
||||
def describe(res, res_str, max_len = 1000):
|
||||
if isinstance(res, list):
|
||||
return f'{len(res)} items'
|
||||
return f'{len(res_str)} chars\n {res_str[:1000] if len(res_str) > max_len else res_str}...'
|
||||
print(f' → {describe(tool_result, tool_result_str)}', file=sys.stderr)
|
||||
if verbose:
|
||||
print(tool_result_str, file=sys.stderr)
|
||||
messages.append(dict(
|
||||
|
|
|
@ -27,4 +27,4 @@ openssl req -new -newkey rsa:4096 -days 3650 -nodes -x509 \
|
|||
|
||||
openssl x509 -outform PEM -in squid/ssl_cert/squidCA.pem -out squid/ssl_cert/squidCA.crt
|
||||
|
||||
docker compose --verbose up --build "$@"
|
||||
docker compose up --build "$@"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue