From d1d86027c46b19ea241e833af0c4f46bb9eed77e Mon Sep 17 00:00:00 2001
From: ochafik
Date: Fri, 29 Mar 2024 19:22:15 +0000
Subject: [PATCH] agent: disable parallel by default

---
 examples/agent/agent.py   |  6 +++---
 examples/openai/api.py    |  2 +-
 examples/openai/server.py | 44 ++++++++++++++++++++++++++++++++++++---
 3 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/examples/agent/agent.py b/examples/agent/agent.py
index 96355d225..1652c4479 100644
--- a/examples/agent/agent.py
+++ b/examples/agent/agent.py
@@ -108,8 +108,8 @@ def completion_with_tool_usage(
                     tool_call_id=tool_call.id,
                     role="tool",
                     name=tool_call.function.name,
-                    # content=f'{tool_result}',
-                    content=f'{pretty_call} = {tool_result}',
+                    content=f'{tool_result}',
+                    # content=f'{pretty_call} = {tool_result}',
                 ))
         else:
             assert content
@@ -129,7 +129,7 @@ def main(
     max_iterations: Optional[int] = 10,
     std_tools: Optional[bool] = False,
     auth: Optional[str] = None,
-    parallel_calls: Optional[bool] = True,
+    parallel_calls: Optional[bool] = False,
     verbose: bool = False,
     style: Optional[ToolsPromptStyle] = None,
 
diff --git a/examples/openai/api.py b/examples/openai/api.py
index b95eb17fa..7780d8bc4 100644
--- a/examples/openai/api.py
+++ b/examples/openai/api.py
@@ -4,7 +4,7 @@ from pydantic import BaseModel, Json, TypeAdapter
 
 class FunctionCall(BaseModel):
     name: str
-    arguments: Dict[str, Any]
+    arguments: Union[Dict[str, Any], str]
 
 class ToolCall(BaseModel):
     id: Optional[str] = None
diff --git a/examples/openai/server.py b/examples/openai/server.py
index a8abe8c8a..6d19f12f1 100644
--- a/examples/openai/server.py
+++ b/examples/openai/server.py
@@ -31,7 +31,7 @@ def main(
     # model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
     host: str = "localhost",
     port: int = 8080,
-    parallel_calls: Optional[bool] = True,
+    parallel_calls: Optional[bool] = False,
     style: Optional[ToolsPromptStyle] = None,
     auth: Optional[str] = None,
     verbose: bool = False,
@@ -75,6 +75,44 @@ def main(
         atexit.register(server_process.kill)
         endpoint = f"http://{server_host}:{server_port}/completions"
 
+
+    # print(chat_template.render([
+    #     Message(**{
+    #         "role": "user",
+    #         "name": None,
+    #         "tool_call_id": None,
+    #         "content": "What is the sum of 2535 squared and 32222000403 then multiplied by one and a half. What's a third of the result?",
+    #         "tool_calls": None
+    #     }),
+    #     Message(**{
+    #         "role": "assistant",
+    #         # "name": None,
+    #         "tool_call_id": None,
+    #         "content": "?",
+    #         "tool_calls": [
+    #             {
+    #                 # "id": "call_531873",
+    #                 "type": "function",
+    #                 "function": {
+    #                     "name": "add",
+    #                     "arguments": {
+    #                         "a": 2535,
+    #                         "b": 32222000403
+    #                     }
+    #                 }
+    #             }
+    #         ]
+    #     }),
+    #     Message(**{
+    #         "role": "tool",
+    #         "name": "add",
+    #         "tool_call_id": "call_531873",
+    #         "content": "32222002938",
+    #         "tool_calls": None
+    #     })
+    # ], add_generation_prompt=True))
+    # exit(0)
+
     app = FastAPI()
 
     @app.post("/v1/chat/completions")
@@ -95,6 +133,7 @@ def main(
             ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools),
             parallel_calls=parallel_calls,
             tool_style=style,
+            verbose=verbose,
         )
 
         messages = chat_request.messages
@@ -102,8 +141,7 @@
             messages = chat_template.add_system_prompt(messages, chat_handler.output_format_prompt)
 
         prompt = chat_template.render(messages, add_generation_prompt=True)
-        
-        
+
         if verbose:
             sys.stderr.write(f'\n# REQUEST:\n\n{chat_request.model_dump_json(indent=2)}\n\n')
             # sys.stderr.write(f'\n# MESSAGES:\n\n{TypeAdapter(list[Message]).dump_json(messages)}\n\n')
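
Note (reviewer sketch, not part of the patch): the arguments field of FunctionCall is widened to Union[Dict[str, Any], str] because OpenAI-style clients serialize tool-call arguments as a JSON-encoded string, while the local agent builds them as a plain dict; the union lets one pydantic model validate both shapes. A minimal sketch of the widened model, with sample values borrowed from the commented-out render test in the diff above:

    # Sketch only; mirrors the field change in examples/openai/api.py
    # and assumes pydantic is installed.
    from typing import Any, Dict, Union
    from pydantic import BaseModel

    class FunctionCall(BaseModel):
        name: str
        # dict when built locally by the agent,
        # str when a client sends JSON-encoded arguments
        arguments: Union[Dict[str, Any], str]

    # Both forms now validate:
    FunctionCall(name="add", arguments={"a": 2535, "b": 32222000403})
    FunctionCall(name="add", arguments='{"a": 2535, "b": 32222000403}')

With the defaults flipped to False, parallel tool calls become opt-in: assuming typer's usual boolean-option handling for these signatures, that means passing --parallel-calls explicitly when invoking agent.py or server.py.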