From ce2fb0155f2952e736fe36bbdd943bd4cc6746bd Mon Sep 17 00:00:00 2001
From: ochafik
Date: Fri, 29 Mar 2024 16:40:23 +0000
Subject: [PATCH] agent: add --allow_parallel_calls

---
 examples/agent/README.md     |  2 ++
 examples/agent/agent.py      | 12 ++++++++++--
 examples/openai/prompting.py | 16 +++++++++-------
 examples/openai/server.py    | 11 ++++++++---
 4 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/examples/agent/README.md b/examples/agent/README.md
index c774bfb31..3ae0fc7d1 100644
--- a/examples/agent/README.md
+++ b/examples/agent/README.md
@@ -157,6 +157,8 @@ REQUIREMENTS_FILE=<( cat examples/agents/hermes_function_calling/requirements.tx
 
 ## TODO
 
+- Wait for spawned servers to be healthy
+
 - Add model URL / HF loading support
 
 - Add Embedding endpoint + storage / retrieval tools (Faiss? ScaNN?), or spontaneous RAG
diff --git a/examples/agent/agent.py b/examples/agent/agent.py
index bbe8223e2..78c1f8252 100644
--- a/examples/agent/agent.py
+++ b/examples/agent/agent.py
@@ -128,6 +128,7 @@ def main(
     max_iterations: Optional[int] = 10,
     std_tools: Optional[bool] = False,
     auth: Optional[str] = None,
+    allow_parallel_calls: Optional[bool] = False,
     verbose: bool = False,
 
     model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
@@ -135,6 +136,8 @@ def main(
     context_length: Optional[int] = None,
 
     # endpoint: str = 'http://localhost:8080/v1/chat/completions',
+    greedy: Optional[bool] = True,
+
     n_predict: Optional[int] = 1000,
     top_k: Optional[int] = None,
     top_p: Optional[float] = None,
@@ -157,6 +160,10 @@ def main(
     n_probs: Optional[int] = None,
     min_keep: Optional[int] = None,
 ):
+    if greedy:
+        top_k = 1
+        top_p = 0.0
+
     if not endpoint:
         server_port = 8080
         server_host = 'localhost'
@@ -167,9 +174,10 @@ def main(
             "python", "-m", "examples.openai.server",
             "--model", model,
             *(['--verbose'] if verbose else []),
-            *([f'--context_length={context_length}'] if context_length else []),
+            *(['--allow-parallel-calls'] if allow_parallel_calls else []),
+            *([f'--context-length={context_length}'] if context_length else []),
+            *([])
         ]
-        print(cmd)
         server_process = subprocess.Popen(cmd, stdout=sys.stderr)
         atexit.register(server_process.kill)
         sleep(5)
diff --git a/examples/openai/prompting.py b/examples/openai/prompting.py
index a6d71e36f..3edea651d 100644
--- a/examples/openai/prompting.py
+++ b/examples/openai/prompting.py
@@ -320,8 +320,8 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
         )
 
 class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
-    def __init__(self, args: ChatHandlerArgs):
-        super().__init__(args, escapes_underscores=False, allow_parallel_calls=False)
+    def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
+        super().__init__(args, escapes_underscores=False, allow_parallel_calls=allow_parallel_calls)
 
         # Hackily import https://github.com/NousResearch/Hermes-Function-Calling
         path = str(Path(__file__).parent / "hermes_function_calling")
@@ -433,7 +433,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
         content = '\n'.join(text_content).strip()
         return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
 
-def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls=False):
+def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls):
     return {
         "type": "object",
         "properties": {
@@ -474,7 +474,7 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls
     }
 
 class BespokeToolsChatHandler(ChatHandler):
-    def __init__(self, args: ChatHandlerArgs):
+    def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
         super().__init__(args)
 
         # args.response_schema = args.response_schema or {}
@@ -496,7 +496,8 @@ class BespokeToolsChatHandler(ChatHandler):
                         }
                         for tool in self.args.tools
                     ]
-                }
+                },
+                allow_parallel_calls=allow_parallel_calls,
             ),
             '',
         )
@@ -523,7 +524,8 @@ class BespokeToolsChatHandler(ChatHandler):
                             }
                         },
                         "required": ["name", "arguments"]
-                    }
+                    },
+                    allow_parallel_calls=allow_parallel_calls,
                 )
             ),
         ])
@@ -589,7 +591,7 @@ def get_chat_handler(args: ChatHandlerArgs, allow_parallel_calls=False) -> ChatH
     elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL:
         return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, escapes_underscores=True, allow_parallel_calls=allow_parallel_calls)
     elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE:
-        return BespokeToolsChatHandler(args)
+        return BespokeToolsChatHandler(args, allow_parallel_calls=allow_parallel_calls)
     elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
         return Hermes2ProToolsChatHandler(args)
     else:
diff --git a/examples/openai/server.py b/examples/openai/server.py
index d2a3aea2d..962796324 100644
--- a/examples/openai/server.py
+++ b/examples/openai/server.py
@@ -31,6 +31,7 @@ def main(
     # model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
     host: str = "localhost",
     port: int = 8080,
+    allow_parallel_calls: Optional[bool] = False,
     auth: Optional[str] = None,
     verbose: bool = False,
     context_length: Optional[int] = None,
@@ -61,14 +62,15 @@ def main(
 
     if verbose:
         sys.stderr.write(f"# Starting C++ server with model {model} on {server_host}:{server_port}\n")
-    server_process = subprocess.Popen([
+    cmd = [
         "./server", "-m", model,
         "--host", server_host, "--port", f'{server_port}',
         # TODO: pass these from JSON / BaseSettings?
         '-ctk', 'q4_0', '-ctv', 'f16',
         "-c", f"{context_length}",
         *([] if verbose else ["--log-disable"]),
-    ], stdout=sys.stderr)
+    ]
+    server_process = subprocess.Popen(cmd, stdout=sys.stderr)
     atexit.register(server_process.kill)
 
     endpoint = f"http://{server_host}:{server_port}/completions"
@@ -88,7 +90,10 @@ def main(
         else:
             response_schema = None
 
-        chat_handler = get_chat_handler(ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools))
+        chat_handler = get_chat_handler(
+            ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools),
+            allow_parallel_calls=allow_parallel_calls
+        )
 
         messages = chat_request.messages
         if chat_handler.output_format_prompt: