tool-call
: fix agent type lints
This commit is contained in:
parent
1e5c0e747e
commit
9295ca95db
4 changed files with 24 additions and 25 deletions
|
@ -11,12 +11,13 @@
|
||||||
# ///
|
# ///
|
||||||
import json
|
import json
|
||||||
import openai
|
import openai
|
||||||
|
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import requests
|
import requests
|
||||||
import sys
|
import sys
|
||||||
import typer
|
import typer
|
||||||
from typing import Annotated, List, Optional
|
from typing import Annotated, Optional
|
||||||
import urllib
|
import urllib.parse
|
||||||
|
|
||||||
|
|
||||||
class OpenAPIMethod:
|
class OpenAPIMethod:
|
||||||
|
@ -94,7 +95,7 @@ class OpenAPIMethod:
|
||||||
def main(
|
def main(
|
||||||
goal: Annotated[str, typer.Option()],
|
goal: Annotated[str, typer.Option()],
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
tool_endpoint: Optional[List[str]] = None,
|
tool_endpoint: Optional[list[str]] = None,
|
||||||
format: Annotated[Optional[str], typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None,
|
format: Annotated[Optional[str], typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None,
|
||||||
max_iterations: Optional[int] = 10,
|
max_iterations: Optional[int] = 10,
|
||||||
parallel_calls: Optional[bool] = False,
|
parallel_calls: Optional[bool] = False,
|
||||||
|
@ -134,8 +135,8 @@ def main(
|
||||||
|
|
||||||
sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')
|
sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')
|
||||||
|
|
||||||
messages = [
|
messages: list[ChatCompletionMessageParam] = [
|
||||||
dict(
|
ChatCompletionUserMessageParam(
|
||||||
role="user",
|
role="user",
|
||||||
content=goal,
|
content=goal,
|
||||||
)
|
)
|
||||||
|
@ -158,7 +159,8 @@ def main(
|
||||||
|
|
||||||
content = choice.message.content
|
content = choice.message.content
|
||||||
if choice.finish_reason == "tool_calls":
|
if choice.finish_reason == "tool_calls":
|
||||||
messages.append(choice.message)
|
messages.append(choice.message) # type: ignore
|
||||||
|
assert choice.message.tool_calls
|
||||||
for tool_call in choice.message.tool_calls:
|
for tool_call in choice.message.tool_calls:
|
||||||
if content:
|
if content:
|
||||||
print(f'💭 {content}')
|
print(f'💭 {content}')
|
||||||
|
@ -169,11 +171,11 @@ def main(
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
tool_result = tool_map[tool_call.function.name](**args)
|
tool_result = tool_map[tool_call.function.name](**args)
|
||||||
sys.stdout.write(f" → {tool_result}\n")
|
sys.stdout.write(f" → {tool_result}\n")
|
||||||
messages.append(dict(
|
messages.append(ChatCompletionToolMessageParam(
|
||||||
tool_call_id=tool_call.id,
|
tool_call_id=tool_call.id,
|
||||||
role="tool",
|
role="tool",
|
||||||
name=tool_call.function.name,
|
# name=tool_call.function.name,
|
||||||
content=f'{tool_result}',
|
content=json.dumps(tool_result),
|
||||||
# content=f'{pretty_call} = {tool_result}',
|
# content=f'{pretty_call} = {tool_result}',
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -59,8 +59,8 @@ def main(files: List[str], host: str = '0.0.0.0', port: int = 8000):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
vt = type(v)
|
vt = type(v)
|
||||||
if vt.__module__ == 'langchain_core.tools' and vt.__name__.endswith('Tool') and hasattr(v, 'func') and callable(v.func):
|
if vt.__module__ == 'langchain_core.tools' and vt.__name__.endswith('Tool') and hasattr(v, 'func') and callable(func := getattr(v, 'func')):
|
||||||
v = v.func
|
v = func
|
||||||
|
|
||||||
print(f'INFO: Binding /{k}')
|
print(f'INFO: Binding /{k}')
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -1,13 +1,10 @@
|
||||||
from datetime import date
|
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import typer
|
|
||||||
from typing import Union, Optional, Dict
|
|
||||||
import types
|
import types
|
||||||
|
from typing import Union, Optional, Dict
|
||||||
|
|
||||||
|
|
||||||
class Duration(BaseModel):
|
class Duration(BaseModel):
|
||||||
|
@ -58,7 +55,7 @@ def wait_for_duration(duration: Duration) -> None:
|
||||||
time.sleep(duration.get_total_seconds)
|
time.sleep(duration.get_total_seconds)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def wait_for_date(target_date: date) -> None:
|
def wait_for_date(target_date: datetime.date) -> None:
|
||||||
f'''
|
f'''
|
||||||
Wait until a specific date is reached before continuing.
|
Wait until a specific date is reached before continuing.
|
||||||
Today's date is {datetime.date.today()}
|
Today's date is {datetime.date.today()}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue