agent: fix wait --std-tools
This commit is contained in:
parent
89dcc062a4
commit
0120f7cc95
2 changed files with 67 additions and 46 deletions
|
@ -3,6 +3,7 @@ import sys
|
||||||
from time import sleep
|
from time import sleep
|
||||||
import typer
|
import typer
|
||||||
from pydantic import BaseModel, Json, TypeAdapter
|
from pydantic import BaseModel, Json, TypeAdapter
|
||||||
|
from pydantic_core import SchemaValidator, core_schema
|
||||||
from typing import Annotated, Any, Callable, Dict, List, Union, Optional, Type
|
from typing import Annotated, Any, Callable, Dict, List, Union, Optional, Type
|
||||||
import json, requests
|
import json, requests
|
||||||
|
|
||||||
|
@ -13,16 +14,12 @@ from examples.agent.utils import collect_functions, load_module
|
||||||
from examples.openai.prompting import ToolsPromptStyle
|
from examples.openai.prompting import ToolsPromptStyle
|
||||||
from examples.openai.subprocesses import spawn_subprocess
|
from examples.openai.subprocesses import spawn_subprocess
|
||||||
|
|
||||||
def _get_params_schema(fn: Callable[[Any], Any], verbose):
|
def make_call_adapter(ta: TypeAdapter, fn: Callable[..., Any]):
|
||||||
if isinstance(fn, OpenAPIMethod):
|
args_validator = SchemaValidator(core_schema.call_schema(
|
||||||
return fn.parameters_schema
|
arguments=ta.core_schema['arguments_schema'],
|
||||||
|
function=fn,
|
||||||
# converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
|
))
|
||||||
schema = TypeAdapter(fn).json_schema()
|
return lambda **kwargs: args_validator.validate_python(kwargs)
|
||||||
# Do NOT call converter.resolve_refs(schema) here. Let the server resolve local refs.
|
|
||||||
if verbose:
|
|
||||||
sys.stderr.write(f'# PARAMS SCHEMA: {json.dumps(schema, indent=2)}\n')
|
|
||||||
return schema
|
|
||||||
|
|
||||||
def completion_with_tool_usage(
|
def completion_with_tool_usage(
|
||||||
*,
|
*,
|
||||||
|
@ -50,18 +47,28 @@ def completion_with_tool_usage(
|
||||||
schema = type_adapter.json_schema()
|
schema = type_adapter.json_schema()
|
||||||
response_format=ResponseFormat(type="json_object", schema=schema)
|
response_format=ResponseFormat(type="json_object", schema=schema)
|
||||||
|
|
||||||
tool_map = {fn.__name__: fn for fn in tools}
|
tool_map = {}
|
||||||
tools_schemas = [
|
tools_schemas = []
|
||||||
Tool(
|
for fn in tools:
|
||||||
type="function",
|
if isinstance(fn, OpenAPIMethod):
|
||||||
function=ToolFunction(
|
tool_map[fn.__name__] = fn
|
||||||
name=fn.__name__,
|
parameters_schema = fn.parameters_schema
|
||||||
description=fn.__doc__ or '',
|
else:
|
||||||
parameters=_get_params_schema(fn, verbose=verbose)
|
ta = TypeAdapter(fn)
|
||||||
|
tool_map[fn.__name__] = make_call_adapter(ta, fn)
|
||||||
|
parameters_schema = ta.json_schema()
|
||||||
|
if verbose:
|
||||||
|
sys.stderr.write(f'# PARAMS SCHEMA ({fn.__name__}): {json.dumps(parameters_schema, indent=2)}\n')
|
||||||
|
tools_schemas.append(
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=ToolFunction(
|
||||||
|
name=fn.__name__,
|
||||||
|
description=fn.__doc__ or '',
|
||||||
|
parameters=parameters_schema,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
for fn in tools
|
|
||||||
]
|
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
while (max_iterations is None or i < max_iterations):
|
while (max_iterations is None or i < max_iterations):
|
||||||
|
@ -106,7 +113,7 @@ def completion_with_tool_usage(
|
||||||
sys.stdout.write(f'⚙️ {pretty_call}')
|
sys.stdout.write(f'⚙️ {pretty_call}')
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
tool_result = tool_map[tool_call.function.name](**tool_call.function.arguments)
|
tool_result = tool_map[tool_call.function.name](**tool_call.function.arguments)
|
||||||
sys.stdout.write(f" -> {tool_result}\n")
|
sys.stdout.write(f" → {tool_result}\n")
|
||||||
messages.append(Message(
|
messages.append(Message(
|
||||||
tool_call_id=tool_call.id,
|
tool_call_id=tool_call.id,
|
||||||
role="tool",
|
role="tool",
|
||||||
|
@ -203,6 +210,8 @@ def main(
|
||||||
if std_tools:
|
if std_tools:
|
||||||
tool_functions.extend(collect_functions(StandardTools))
|
tool_functions.extend(collect_functions(StandardTools))
|
||||||
|
|
||||||
|
sys.stdout.write(f'🛠️ {", ".join(fn.__name__ for fn in tool_functions)}\n')
|
||||||
|
|
||||||
response_model: Union[type, Json[Any]] = None #str
|
response_model: Union[type, Json[Any]] = None #str
|
||||||
if format:
|
if format:
|
||||||
if format in types:
|
if format in types:
|
||||||
|
|
|
@ -16,7 +16,18 @@ class Duration(BaseModel):
|
||||||
years: Optional[int] = None
|
years: Optional[int] = None
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"{self.years} years, {self.months} months, {self.days} days, {self.hours} hours, {self.minutes} minutes, {self.seconds} seconds"
|
return ', '.join([
|
||||||
|
x
|
||||||
|
for x in [
|
||||||
|
f"{self.years} years" if self.years else None,
|
||||||
|
f"{self.months} months" if self.months else None,
|
||||||
|
f"{self.days} days" if self.days else None,
|
||||||
|
f"{self.hours} hours" if self.hours else None,
|
||||||
|
f"{self.minutes} minutes" if self.minutes else None,
|
||||||
|
f"{self.seconds} seconds" if self.seconds else None,
|
||||||
|
]
|
||||||
|
if x is not None
|
||||||
|
])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def get_total_seconds(self) -> int:
|
def get_total_seconds(self) -> int:
|
||||||
|
@ -36,25 +47,6 @@ class WaitForDuration(BaseModel):
|
||||||
sys.stderr.write(f"Waiting for {self.duration}...\n")
|
sys.stderr.write(f"Waiting for {self.duration}...\n")
|
||||||
time.sleep(self.duration.get_total_seconds)
|
time.sleep(self.duration.get_total_seconds)
|
||||||
|
|
||||||
class WaitForDate(BaseModel):
|
|
||||||
until: date
|
|
||||||
|
|
||||||
def __call__(self):
|
|
||||||
# Get the current date
|
|
||||||
current_date = datetime.date.today()
|
|
||||||
|
|
||||||
if self.until < current_date:
|
|
||||||
raise ValueError("Target date cannot be in the past.")
|
|
||||||
|
|
||||||
time_diff = datetime.datetime.combine(self.until, datetime.time.min) - datetime.datetime.combine(current_date, datetime.time.min)
|
|
||||||
|
|
||||||
days, seconds = time_diff.days, time_diff.seconds
|
|
||||||
|
|
||||||
sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {self.until}...\n")
|
|
||||||
time.sleep(days * 86400 + seconds)
|
|
||||||
sys.stderr.write(f"Reached the target date: {self.until}\n")
|
|
||||||
|
|
||||||
|
|
||||||
class StandardTools:
|
class StandardTools:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -66,12 +58,32 @@ class StandardTools:
|
||||||
return typer.prompt(question)
|
return typer.prompt(question)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def wait(_for: Union[WaitForDuration, WaitForDate]) -> None:
|
def wait_for_duration(duration: Duration) -> None:
|
||||||
|
'Wait for a certain amount of time before continuing.'
|
||||||
|
|
||||||
|
# sys.stderr.write(f"Waiting for {duration}...\n")
|
||||||
|
time.sleep(duration.get_total_seconds)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def wait_for_date(target_date: date) -> None:
|
||||||
|
f'''
|
||||||
|
Wait until a specific date is reached before continuing.
|
||||||
|
Today's date is {datetime.date.today()}
|
||||||
'''
|
'''
|
||||||
Wait for a certain amount of time before continuing.
|
|
||||||
This can be used to wait for a specific duration or until a specific date.
|
# Get the current date
|
||||||
'''
|
current_date = datetime.date.today()
|
||||||
return _for()
|
|
||||||
|
if target_date < current_date:
|
||||||
|
raise ValueError("Target date cannot be in the past.")
|
||||||
|
|
||||||
|
time_diff = datetime.datetime.combine(target_date, datetime.time.min) - datetime.datetime.combine(current_date, datetime.time.min)
|
||||||
|
|
||||||
|
days, seconds = time_diff.days, time_diff.seconds
|
||||||
|
|
||||||
|
# sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {target_date}...\n")
|
||||||
|
time.sleep(days * 86400 + seconds)
|
||||||
|
# sys.stderr.write(f"Reached the target date: {target_date}\n")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def say_out_loud(something: str) -> None:
|
def say_out_loud(something: str) -> None:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue