agent: fix wait --std-tools
This commit is contained in:
parent
89dcc062a4
commit
0120f7cc95
2 changed files with 67 additions and 46 deletions
|
@ -3,6 +3,7 @@ import sys
|
|||
from time import sleep
|
||||
import typer
|
||||
from pydantic import BaseModel, Json, TypeAdapter
|
||||
from pydantic_core import SchemaValidator, core_schema
|
||||
from typing import Annotated, Any, Callable, Dict, List, Union, Optional, Type
|
||||
import json, requests
|
||||
|
||||
|
@ -13,16 +14,12 @@ from examples.agent.utils import collect_functions, load_module
|
|||
from examples.openai.prompting import ToolsPromptStyle
|
||||
from examples.openai.subprocesses import spawn_subprocess
|
||||
|
||||
def _get_params_schema(fn: Callable[[Any], Any], verbose):
|
||||
if isinstance(fn, OpenAPIMethod):
|
||||
return fn.parameters_schema
|
||||
|
||||
# converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
|
||||
schema = TypeAdapter(fn).json_schema()
|
||||
# Do NOT call converter.resolve_refs(schema) here. Let the server resolve local refs.
|
||||
if verbose:
|
||||
sys.stderr.write(f'# PARAMS SCHEMA: {json.dumps(schema, indent=2)}\n')
|
||||
return schema
|
||||
def make_call_adapter(ta: TypeAdapter, fn: Callable[..., Any]):
|
||||
args_validator = SchemaValidator(core_schema.call_schema(
|
||||
arguments=ta.core_schema['arguments_schema'],
|
||||
function=fn,
|
||||
))
|
||||
return lambda **kwargs: args_validator.validate_python(kwargs)
|
||||
|
||||
def completion_with_tool_usage(
|
||||
*,
|
||||
|
@ -50,18 +47,28 @@ def completion_with_tool_usage(
|
|||
schema = type_adapter.json_schema()
|
||||
response_format=ResponseFormat(type="json_object", schema=schema)
|
||||
|
||||
tool_map = {fn.__name__: fn for fn in tools}
|
||||
tools_schemas = [
|
||||
Tool(
|
||||
type="function",
|
||||
function=ToolFunction(
|
||||
name=fn.__name__,
|
||||
description=fn.__doc__ or '',
|
||||
parameters=_get_params_schema(fn, verbose=verbose)
|
||||
tool_map = {}
|
||||
tools_schemas = []
|
||||
for fn in tools:
|
||||
if isinstance(fn, OpenAPIMethod):
|
||||
tool_map[fn.__name__] = fn
|
||||
parameters_schema = fn.parameters_schema
|
||||
else:
|
||||
ta = TypeAdapter(fn)
|
||||
tool_map[fn.__name__] = make_call_adapter(ta, fn)
|
||||
parameters_schema = ta.json_schema()
|
||||
if verbose:
|
||||
sys.stderr.write(f'# PARAMS SCHEMA ({fn.__name__}): {json.dumps(parameters_schema, indent=2)}\n')
|
||||
tools_schemas.append(
|
||||
Tool(
|
||||
type="function",
|
||||
function=ToolFunction(
|
||||
name=fn.__name__,
|
||||
description=fn.__doc__ or '',
|
||||
parameters=parameters_schema,
|
||||
)
|
||||
)
|
||||
)
|
||||
for fn in tools
|
||||
]
|
||||
|
||||
i = 0
|
||||
while (max_iterations is None or i < max_iterations):
|
||||
|
@ -106,7 +113,7 @@ def completion_with_tool_usage(
|
|||
sys.stdout.write(f'⚙️ {pretty_call}')
|
||||
sys.stdout.flush()
|
||||
tool_result = tool_map[tool_call.function.name](**tool_call.function.arguments)
|
||||
sys.stdout.write(f" -> {tool_result}\n")
|
||||
sys.stdout.write(f" → {tool_result}\n")
|
||||
messages.append(Message(
|
||||
tool_call_id=tool_call.id,
|
||||
role="tool",
|
||||
|
@ -203,6 +210,8 @@ def main(
|
|||
if std_tools:
|
||||
tool_functions.extend(collect_functions(StandardTools))
|
||||
|
||||
sys.stdout.write(f'🛠️ {", ".join(fn.__name__ for fn in tool_functions)}\n')
|
||||
|
||||
response_model: Union[type, Json[Any]] = None #str
|
||||
if format:
|
||||
if format in types:
|
||||
|
|
|
@ -16,7 +16,18 @@ class Duration(BaseModel):
|
|||
years: Optional[int] = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.years} years, {self.months} months, {self.days} days, {self.hours} hours, {self.minutes} minutes, {self.seconds} seconds"
|
||||
return ', '.join([
|
||||
x
|
||||
for x in [
|
||||
f"{self.years} years" if self.years else None,
|
||||
f"{self.months} months" if self.months else None,
|
||||
f"{self.days} days" if self.days else None,
|
||||
f"{self.hours} hours" if self.hours else None,
|
||||
f"{self.minutes} minutes" if self.minutes else None,
|
||||
f"{self.seconds} seconds" if self.seconds else None,
|
||||
]
|
||||
if x is not None
|
||||
])
|
||||
|
||||
@property
|
||||
def get_total_seconds(self) -> int:
|
||||
|
@ -36,25 +47,6 @@ class WaitForDuration(BaseModel):
|
|||
sys.stderr.write(f"Waiting for {self.duration}...\n")
|
||||
time.sleep(self.duration.get_total_seconds)
|
||||
|
||||
class WaitForDate(BaseModel):
|
||||
until: date
|
||||
|
||||
def __call__(self):
|
||||
# Get the current date
|
||||
current_date = datetime.date.today()
|
||||
|
||||
if self.until < current_date:
|
||||
raise ValueError("Target date cannot be in the past.")
|
||||
|
||||
time_diff = datetime.datetime.combine(self.until, datetime.time.min) - datetime.datetime.combine(current_date, datetime.time.min)
|
||||
|
||||
days, seconds = time_diff.days, time_diff.seconds
|
||||
|
||||
sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {self.until}...\n")
|
||||
time.sleep(days * 86400 + seconds)
|
||||
sys.stderr.write(f"Reached the target date: {self.until}\n")
|
||||
|
||||
|
||||
class StandardTools:
|
||||
|
||||
@staticmethod
|
||||
|
@ -66,12 +58,32 @@ class StandardTools:
|
|||
return typer.prompt(question)
|
||||
|
||||
@staticmethod
|
||||
def wait(_for: Union[WaitForDuration, WaitForDate]) -> None:
|
||||
def wait_for_duration(duration: Duration) -> None:
|
||||
'Wait for a certain amount of time before continuing.'
|
||||
|
||||
# sys.stderr.write(f"Waiting for {duration}...\n")
|
||||
time.sleep(duration.get_total_seconds)
|
||||
|
||||
@staticmethod
|
||||
def wait_for_date(target_date: date) -> None:
|
||||
f'''
|
||||
Wait until a specific date is reached before continuing.
|
||||
Today's date is {datetime.date.today()}
|
||||
'''
|
||||
Wait for a certain amount of time before continuing.
|
||||
This can be used to wait for a specific duration or until a specific date.
|
||||
'''
|
||||
return _for()
|
||||
|
||||
# Get the current date
|
||||
current_date = datetime.date.today()
|
||||
|
||||
if target_date < current_date:
|
||||
raise ValueError("Target date cannot be in the past.")
|
||||
|
||||
time_diff = datetime.datetime.combine(target_date, datetime.time.min) - datetime.datetime.combine(current_date, datetime.time.min)
|
||||
|
||||
days, seconds = time_diff.days, time_diff.seconds
|
||||
|
||||
# sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {target_date}...\n")
|
||||
time.sleep(days * 86400 + seconds)
|
||||
# sys.stderr.write(f"Reached the target date: {target_date}\n")
|
||||
|
||||
@staticmethod
|
||||
def say_out_loud(something: str) -> None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue