agent: allow interactive chat by default, and don't reuse sessions

parent 6f2191d99e
commit 26e76f9704

1 changed file with 22 additions and 10 deletions

@@ -77,7 +77,7 @@ class OpenAPIMethod:
         if components:
             self.parameters_schema['components'] = components
 
-    async def __call__(self, session: aiohttp.ClientSession, **kwargs):
+    async def __call__(self, **kwargs):
         if self.body:
             body = kwargs.pop(self.body['name'], None)
             if self.body['required']:
@@ -96,9 +96,10 @@ class OpenAPIMethod:
 
         params = '&'.join(f'{name}={urllib.parse.quote(str(value))}' for name, value in query_params.items() if value is not None)
         url = f'{self.url}?{params}'
-        async with session.post(url, json=body) as response:
-            response.raise_for_status()
-            response_json = await response.json()
+        async with aiohttp.ClientSession() as session:
+            async with session.post(url, json=body) as response:
+                response.raise_for_status()
+                response_json = await response.json()
 
         return response_json
 
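
Note: as reconstructed above, OpenAPIMethod.__call__ no longer receives a shared aiohttp.ClientSession; it opens a short-lived session for each request instead. A minimal sketch of that per-call pattern, with a hypothetical fetch_json helper standing in for the method body:

import aiohttp

# Hypothetical helper illustrating the per-call session pattern: a fresh
# ClientSession is opened for each request and closed when the 'async with'
# block exits, so no session object needs to be threaded through callers.
async def fetch_json(url: str, body: dict) -> dict:
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=body) as response:
            response.raise_for_status()
            return await response.json()

Assuming typical aiohttp behaviour, each call pays for its own connection setup instead of reusing a pooled connection, in exchange for not having to keep a long-lived session alive across interactive turns.
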
@@ -131,6 +132,7 @@ async def discover_tools(tool_endpoints: list[str], logger) -> tuple[dict, list]
 
     return tool_map, tools
 
+
 def typer_async_workaround():
     'Adapted from https://github.com/fastapi/typer/issues/950#issuecomment-2351076467'
     def decorator(f):
@@ -149,6 +151,7 @@ async def main(
     verbose: bool = False,
     cache_prompt: bool = True,
     seed: Optional[int] = None,
+    interactive: bool = True,
     openai: bool = False,
     endpoint: Optional[str] = None,
     api_key: Optional[str] = None,
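
Note: because main is run through typer.run (see the end of the diff), a plain bool parameter with a True default is what makes interactive chat the default at the CLI. A tiny standalone example of that typer behaviour; the hello function is purely illustrative:

import typer

def hello(interactive: bool = True):
    # typer turns a bool parameter into an --interactive / --no-interactive
    # flag pair; leaving the flag off keeps the default of True.
    typer.echo(f'interactive={interactive}')

if __name__ == '__main__':
    typer.run(hello)
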
@@ -180,7 +183,7 @@ async def main(
         'Content-Type': 'application/json',
         'Authorization': f'Bearer {api_key}'
     }
-    async with aiohttp.ClientSession(headers=headers) as session:
+    async def run_turn():
         for i in range(max_iterations or sys.maxsize):
             url = f'{endpoint}chat/completions'
             payload = dict(
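
Note: turning the request loop into a nested run_turn() coroutine means each invocation closes over the surrounding state (messages, headers and the CLI options) rather than receiving it as arguments, so the same turn logic can be re-run after every user prompt. A reduced sketch of that closure shape, with illustrative names only:

import asyncio

async def main():
    messages = []  # shared chat history, mutated in place by every turn

    async def run_turn():
        # Closure over 'messages': each call sees the latest history and
        # appends to it; nothing needs to be passed in or returned.
        messages.append(dict(role='assistant', content=f'reply #{len(messages)}'))

    await run_turn()
    await run_turn()
    print(messages)

if __name__ == '__main__':
    asyncio.run(main())
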
@@ -195,10 +198,11 @@ async def main(
             )) # type: ignore
 
             logger.debug('Calling %s with %s', url, json.dumps(payload, indent=2))
-            async with session.post(url, json=payload) as response:
-                logger.debug('Response: %s', response)
-                response.raise_for_status()
-                response = await response.json()
+            async with aiohttp.ClientSession(headers=headers) as session:
+                async with session.post(url, json=payload) as response:
+                    logger.debug('Response: %s', response)
+                    response.raise_for_status()
+                    response = await response.json()
 
             assert len(response['choices']) == 1
             choice = response['choices'][0]
@@ -216,7 +220,7 @@ async def main(
                     pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
                     logger.info(f'⚙️ {pretty_call}')
                     sys.stdout.flush()
-                    tool_result = await tool_map[name](session, **args)
+                    tool_result = await tool_map[name](**args)
                     tool_result_str = json.dumps(tool_result)
                     logger.info(' → %d chars', len(tool_result_str))
                     logger.debug('%s', tool_result_str)
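
Note: with the session argument gone, invoking a tool is just a dictionary lookup of an async callable followed by an await with the parsed keyword arguments. A toy version of that dispatch; get_time and its result are purely illustrative, not part of this file:

import asyncio
import json

async def get_time(timezone: str) -> dict:
    # Illustrative stand-in for a discovered OpenAPI tool.
    return dict(timezone=timezone, time='12:00')

tool_map = {'get_time': get_time}

async def main():
    name, args = 'get_time', {'timezone': 'UTC'}
    tool_result = await tool_map[name](**args)
    print(json.dumps(tool_result))

if __name__ == '__main__':
    asyncio.run(main())
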
@@ -233,5 +237,13 @@ async def main(
         if max_iterations is not None:
             raise Exception(f'Failed to get a valid response after {max_iterations} tool calls')
 
+    while interactive:
+        await run_turn()
+        messages.append(dict(
+            role='user',
+            content=input('💬 ')
+        ))
+
+
 if __name__ == '__main__':
     typer.run(main)
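
Note: taken together, the new control flow after the tool-calling loop is a simple read/eval/print loop: run one turn, then block on input() for the next user message. A self-contained sketch with a stubbed run_turn standing in for the real chat/completions loop; the chat function and its echo reply are assumptions for illustration only:

import asyncio

async def chat(interactive: bool = True):
    messages = [dict(role='user', content='hello')]

    async def run_turn():
        # Stub standing in for the real request/tool-call loop.
        reply = f"(echo) {messages[-1]['content']}"
        messages.append(dict(role='assistant', content=reply))
        print(reply)

    while interactive:
        await run_turn()
        messages.append(dict(
            role='user',
            content=input('💬 ')
        ))

if __name__ == '__main__':
    asyncio.run(chat())

As in the diff, input() blocks the event loop between turns, which is harmless here because nothing else is awaited concurrently while waiting for the user.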