agent: disable parallel by default

ochafik 2024-03-29 19:22:15 +00:00
parent b4e292ec01
commit d1d86027c4
3 changed files with 45 additions and 7 deletions

View file

@@ -108,8 +108,8 @@ def completion_with_tool_usage(
             tool_call_id=tool_call.id,
             role="tool",
             name=tool_call.function.name,
-            # content=f'{tool_result}',
-            content=f'{pretty_call} = {tool_result}',
+            content=f'{tool_result}',
+            # content=f'{pretty_call} = {tool_result}',
         ))
     else:
         assert content
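The hunk above flips which formatting of the tool-result message is live: the agent now feeds back only the raw tool result instead of a "pretty_call = result" string. A minimal sketch of the message that gets appended, using a plain dict; the field names (role, name, tool_call_id, content) and the example call id/result come from this commit, while the value of pretty_call is a hypothetical illustration:

# Illustrative only; the dict stands in for the agent's Message object.
tool_result = 32222002938
pretty_call = "add(a=2535, b=32222000403)"   # hypothetical pretty-printed call

# Before: content = f'{pretty_call} = {tool_result}'  -> "add(a=2535, b=32222000403) = 32222002938"
# After:  only the raw result goes back to the model.
tool_message = dict(
    role="tool",
    name="add",
    tool_call_id="call_531873",
    content=f'{tool_result}',                # -> "32222002938"
)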
@@ -129,7 +129,7 @@ def main(
     max_iterations: Optional[int] = 10,
     std_tools: Optional[bool] = False,
     auth: Optional[str] = None,
-    parallel_calls: Optional[bool] = True,
+    parallel_calls: Optional[bool] = False,
     verbose: bool = False,
     style: Optional[ToolsPromptStyle] = None,
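Both this agent CLI and the server CLI further down carry the same default flip: parallel tool calls are now opt-in. A hedged sketch of what that means at the command line, assuming a Typer entry point as the commented `typer.Option` hint in the third file suggests; only the parameter names and defaults mirror the diff, the wiring and body are placeholders:

# Sketch, not the repository's file.
from typing import Optional
import typer

app = typer.Typer()

@app.command()
def main(
    max_iterations: Optional[int] = 10,
    std_tools: Optional[bool] = False,
    auth: Optional[str] = None,
    parallel_calls: Optional[bool] = False,   # was True before this commit
    verbose: bool = False,
):
    # Typer exposes the bool as --parallel-calls / --no-parallel-calls flags,
    # so parallel tool calls must now be requested explicitly.
    typer.echo(f"parallel_calls={parallel_calls}")

if __name__ == "__main__":
    app()

Run without flags this prints parallel_calls=False; passing --parallel-calls restores the old behaviour.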

View file

@@ -4,7 +4,7 @@ from pydantic import BaseModel, Json, TypeAdapter
 class FunctionCall(BaseModel):
     name: str
-    arguments: Dict[str, Any]
+    arguments: Union[Dict[str, Any], str]
 
 class ToolCall(BaseModel):
     id: Optional[str] = None
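Widening `arguments` to `Union[Dict[str, Any], str]` lets the schema accept tool-call arguments either as a parsed object or as a raw JSON-encoded string, presumably because some models or backends emit the latter. A self-contained sketch, not the repository's exact models; the `type` and `function` fields on `ToolCall` are assumptions mirroring the OpenAI shape:

import json
from typing import Any, Dict, Optional, Union
from pydantic import BaseModel

class FunctionCall(BaseModel):
    name: str
    arguments: Union[Dict[str, Any], str]

class ToolCall(BaseModel):
    id: Optional[str] = None
    type: str = "function"            # assumed field, following the OpenAI schema
    function: FunctionCall            # assumed field

def parsed_arguments(call: FunctionCall) -> Dict[str, Any]:
    # Normalize either representation to a dict before dispatching the tool.
    return json.loads(call.arguments) if isinstance(call.arguments, str) else call.arguments

# Both representations now validate:
as_dict = FunctionCall(name="add", arguments={"a": 2535, "b": 32222000403})
as_str = FunctionCall(name="add", arguments='{"a": 2535, "b": 32222000403}')
assert parsed_arguments(as_dict) == parsed_arguments(as_str)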

View file

@@ -31,7 +31,7 @@ def main(
     # model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
     host: str = "localhost",
     port: int = 8080,
-    parallel_calls: Optional[bool] = True,
+    parallel_calls: Optional[bool] = False,
     style: Optional[ToolsPromptStyle] = None,
     auth: Optional[str] = None,
     verbose: bool = False,
@@ -75,6 +75,44 @@ def main(
     atexit.register(server_process.kill)
 
     endpoint = f"http://{server_host}:{server_port}/completions"
+
+    # print(chat_template.render([
+    #     Message(**{
+    #         "role": "user",
+    #         "name": None,
+    #         "tool_call_id": None,
+    #         "content": "What is the sum of 2535 squared and 32222000403 then multiplied by one and a half. What's a third of the result?",
+    #         "tool_calls": None
+    #     }),
+    #     Message(**{
+    #         "role": "assistant",
+    #         # "name": None,
+    #         "tool_call_id": None,
+    #         "content": "?",
+    #         "tool_calls": [
+    #             {
+    #                 # "id": "call_531873",
+    #                 "type": "function",
+    #                 "function": {
+    #                     "name": "add",
+    #                     "arguments": {
+    #                         "a": 2535,
+    #                         "b": 32222000403
+    #                     }
+    #                 }
+    #             }
+    #         ]
+    #     }),
+    #     Message(**{
+    #         "role": "tool",
+    #         "name": "add",
+    #         "tool_call_id": "call_531873",
+    #         "content": "32222002938",
+    #         "tool_calls": None
+    #     })
+    # ], add_generation_prompt=True))
+    # exit(0)
+
     app = FastAPI()
 
     @app.post("/v1/chat/completions")
@@ -95,6 +133,7 @@ def main(
             ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools),
             parallel_calls=parallel_calls,
             tool_style=style,
+            verbose=verbose,
         )
 
         messages = chat_request.messages
@@ -102,8 +141,7 @@ def main(
         messages = chat_template.add_system_prompt(messages, chat_handler.output_format_prompt)
 
         prompt = chat_template.render(messages, add_generation_prompt=True)
         if verbose:
             sys.stderr.write(f'\n# REQUEST:\n\n{chat_request.model_dump_json(indent=2)}\n\n')
             # sys.stderr.write(f'\n# MESSAGES:\n\n{TypeAdapter(list[Message]).dump_json(messages)}\n\n')
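Taken together, the last two hunks thread `verbose` from the CLI into the chat handler and keep the request dump behind that flag. A self-contained sketch of the pattern; every class below is a stand-in, and only the keyword arguments (parallel_calls, tool_style, verbose) and the verbose-gated sys.stderr write mirror the diff:

import sys
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ChatHandler:
    parallel_calls: bool
    tool_style: Optional[str]
    verbose: bool
    output_format_prompt: str = "Answer with a single JSON tool call."

class ChatTemplate:
    def add_system_prompt(self, messages: List[dict], prompt: str) -> List[dict]:
        return [{"role": "system", "content": prompt}, *messages]

    def render(self, messages: List[dict], add_generation_prompt: bool = True) -> str:
        rendered = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
        return rendered + ("\nassistant:" if add_generation_prompt else "")

def build_prompt(messages, chat_template, parallel_calls=False, style=None, verbose=False):
    handler = ChatHandler(parallel_calls=parallel_calls,  # now opt-in at the CLI
                          tool_style=style,
                          verbose=verbose)                # newly threaded through
    messages = chat_template.add_system_prompt(messages, handler.output_format_prompt)
    prompt = chat_template.render(messages, add_generation_prompt=True)
    if verbose:
        # Diagnostics only appear when --verbose is passed.
        sys.stderr.write(f"\n# PROMPT:\n\n{prompt}\n\n")
    return prompt

print(build_prompt([{"role": "user", "content": "add 2535 and 32222000403"}],
                   ChatTemplate(), verbose=True))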