agent: disable parallel by default

This commit is contained in:
ochafik 2024-03-29 19:22:15 +00:00
parent b4e292ec01
commit d1d86027c4
3 changed files with 45 additions and 7 deletions

View file

@ -108,8 +108,8 @@ def completion_with_tool_usage(
tool_call_id=tool_call.id,
role="tool",
name=tool_call.function.name,
# content=f'{tool_result}',
content=f'{pretty_call} = {tool_result}',
content=f'{tool_result}',
# content=f'{pretty_call} = {tool_result}',
))
else:
assert content
@ -129,7 +129,7 @@ def main(
max_iterations: Optional[int] = 10,
std_tools: Optional[bool] = False,
auth: Optional[str] = None,
parallel_calls: Optional[bool] = True,
parallel_calls: Optional[bool] = False,
verbose: bool = False,
style: Optional[ToolsPromptStyle] = None,

View file

@ -4,7 +4,7 @@ from pydantic import BaseModel, Json, TypeAdapter
class FunctionCall(BaseModel):
name: str
arguments: Dict[str, Any]
arguments: Union[Dict[str, Any], str]
class ToolCall(BaseModel):
id: Optional[str] = None

View file

@ -31,7 +31,7 @@ def main(
# model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
host: str = "localhost",
port: int = 8080,
parallel_calls: Optional[bool] = True,
parallel_calls: Optional[bool] = False,
style: Optional[ToolsPromptStyle] = None,
auth: Optional[str] = None,
verbose: bool = False,
@ -75,6 +75,44 @@ def main(
atexit.register(server_process.kill)
endpoint = f"http://{server_host}:{server_port}/completions"
# print(chat_template.render([
# Message(**{
# "role": "user",
# "name": None,
# "tool_call_id": None,
# "content": "What is the sum of 2535 squared and 32222000403 then multiplied by one and a half. What's a third of the result?",
# "tool_calls": None
# }),
# Message(**{
# "role": "assistant",
# # "name": None,
# "tool_call_id": None,
# "content": "?",
# "tool_calls": [
# {
# # "id": "call_531873",
# "type": "function",
# "function": {
# "name": "add",
# "arguments": {
# "a": 2535,
# "b": 32222000403
# }
# }
# }
# ]
# }),
# Message(**{
# "role": "tool",
# "name": "add",
# "tool_call_id": "call_531873",
# "content": "32222002938",
# "tool_calls": None
# })
# ], add_generation_prompt=True))
# exit(0)
app = FastAPI()
@app.post("/v1/chat/completions")
@ -95,6 +133,7 @@ def main(
ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools),
parallel_calls=parallel_calls,
tool_style=style,
verbose=verbose,
)
messages = chat_request.messages
@ -103,7 +142,6 @@ def main(
prompt = chat_template.render(messages, add_generation_prompt=True)
if verbose:
sys.stderr.write(f'\n# REQUEST:\n\n{chat_request.model_dump_json(indent=2)}\n\n')
# sys.stderr.write(f'\n# MESSAGES:\n\n{TypeAdapter(list[Message]).dump_json(messages)}\n\n')