tool-call: fix agent type lints

This commit is contained in:
ochafik 2024-09-27 03:53:56 +01:00
parent 1e5c0e747e
commit 9295ca95db
4 changed files with 24 additions and 25 deletions

View file

@ -11,12 +11,13 @@
# /// # ///
import json import json
import openai import openai
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam
from pydantic import BaseModel from pydantic import BaseModel
import requests import requests
import sys import sys
import typer import typer
from typing import Annotated, List, Optional from typing import Annotated, Optional
import urllib import urllib.parse
class OpenAPIMethod: class OpenAPIMethod:
@ -94,7 +95,7 @@ class OpenAPIMethod:
def main( def main(
goal: Annotated[str, typer.Option()], goal: Annotated[str, typer.Option()],
api_key: Optional[str] = None, api_key: Optional[str] = None,
tool_endpoint: Optional[List[str]] = None, tool_endpoint: Optional[list[str]] = None,
format: Annotated[Optional[str], typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None, format: Annotated[Optional[str], typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None,
max_iterations: Optional[int] = 10, max_iterations: Optional[int] = 10,
parallel_calls: Optional[bool] = False, parallel_calls: Optional[bool] = False,
@ -134,8 +135,8 @@ def main(
sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n') sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')
messages = [ messages: list[ChatCompletionMessageParam] = [
dict( ChatCompletionUserMessageParam(
role="user", role="user",
content=goal, content=goal,
) )
@ -158,7 +159,8 @@ def main(
content = choice.message.content content = choice.message.content
if choice.finish_reason == "tool_calls": if choice.finish_reason == "tool_calls":
messages.append(choice.message) messages.append(choice.message) # type: ignore
assert choice.message.tool_calls
for tool_call in choice.message.tool_calls: for tool_call in choice.message.tool_calls:
if content: if content:
print(f'💭 {content}') print(f'💭 {content}')
@ -169,11 +171,11 @@ def main(
sys.stdout.flush() sys.stdout.flush()
tool_result = tool_map[tool_call.function.name](**args) tool_result = tool_map[tool_call.function.name](**args)
sys.stdout.write(f"{tool_result}\n") sys.stdout.write(f"{tool_result}\n")
messages.append(dict( messages.append(ChatCompletionToolMessageParam(
tool_call_id=tool_call.id, tool_call_id=tool_call.id,
role="tool", role="tool",
name=tool_call.function.name, # name=tool_call.function.name,
content=f'{tool_result}', content=json.dumps(tool_result),
# content=f'{pretty_call} = {tool_result}', # content=f'{pretty_call} = {tool_result}',
)) ))
else: else:

View file

@ -59,8 +59,8 @@ def main(files: List[str], host: str = '0.0.0.0', port: int = 8000):
continue continue
vt = type(v) vt = type(v)
if vt.__module__ == 'langchain_core.tools' and vt.__name__.endswith('Tool') and hasattr(v, 'func') and callable(v.func): if vt.__module__ == 'langchain_core.tools' and vt.__name__.endswith('Tool') and hasattr(v, 'func') and callable(func := getattr(v, 'func')):
v = v.func v = func
print(f'INFO: Binding /{k}') print(f'INFO: Binding /{k}')
try: try:

View file

@ -1,13 +1,10 @@
from datetime import date
import datetime import datetime
import json import json
from pydantic import BaseModel from pydantic import BaseModel
import subprocess
import sys import sys
import time import time
import typer
from typing import Union, Optional, Dict
import types import types
from typing import Union, Optional, Dict
class Duration(BaseModel): class Duration(BaseModel):
@ -58,7 +55,7 @@ def wait_for_duration(duration: Duration) -> None:
time.sleep(duration.get_total_seconds) time.sleep(duration.get_total_seconds)
@staticmethod @staticmethod
def wait_for_date(target_date: date) -> None: def wait_for_date(target_date: datetime.date) -> None:
f''' f'''
Wait until a specific date is reached before continuing. Wait until a specific date is reached before continuing.
Today's date is {datetime.date.today()} Today's date is {datetime.date.today()}