agent: add --allow_parallel_calls

Author: ochafik
Date: 2024-03-29 16:40:23 +00:00
Commit: ce2fb0155f (parent: c340e8cd3b)
4 changed files with 29 additions and 12 deletions

View file

@@ -157,6 +157,8 @@ REQUIREMENTS_FILE=<( cat examples/agents/hermes_function_calling/requirements.txt
 ## TODO
+- Wait for spawned servers to be healthy
 - Add model URL / HF loading support
 - Add Embedding endpoint + storage / retrieval tools (Faiss? ScaNN?), or spontaneous RAG
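For the first TODO item ("wait for spawned servers to be healthy"), a minimal sketch of what that wait could look like, assuming the spawned llama.cpp server exposes a GET /health endpoint; the function name and polling details are illustrative, not part of this commit:

import time
import urllib.request

def wait_for_server_health(host: str, port: int, timeout_s: float = 30.0) -> None:
    # Poll the (assumed) /health endpoint until the server answers, or give up.
    url = f"http://{host}:{port}/health"
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=1) as resp:
                if resp.status == 200:
                    return
        except OSError:
            pass  # server not listening yet; retry shortly
        time.sleep(0.25)
    raise TimeoutError(f"server at {url} did not become healthy within {timeout_s}s")

Something like this would replace the unconditional sleep(5) in the agent's server-spawning path below.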

View file

@@ -128,6 +128,7 @@ def main(
     max_iterations: Optional[int] = 10,
     std_tools: Optional[bool] = False,
     auth: Optional[str] = None,
+    allow_parallel_calls: Optional[bool] = False,
     verbose: bool = False,
     model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
@@ -135,6 +136,8 @@ def main(
     context_length: Optional[int] = None,
     # endpoint: str = 'http://localhost:8080/v1/chat/completions',
+    greedy: Optional[bool] = True,
     n_predict: Optional[int] = 1000,
     top_k: Optional[int] = None,
     top_p: Optional[float] = None,
@@ -157,6 +160,10 @@ def main(
     n_probs: Optional[int] = None,
     min_keep: Optional[int] = None,
 ):
+    if greedy:
+        top_k = 1
+        top_p = 0.0
+
     if not endpoint:
         server_port = 8080
         server_host = 'localhost'
@@ -167,9 +174,10 @@ def main(
             "python", "-m", "examples.openai.server",
             "--model", model,
             *(['--verbose'] if verbose else []),
-            *([f'--context_length={context_length}'] if context_length else []),
-            *([])
+            *(['--allow-parallel-calls'] if allow_parallel_calls else []),
+            *([f'--context-length={context_length}'] if context_length else []),
         ]
+        print(cmd)
         server_process = subprocess.Popen(cmd, stdout=sys.stderr)
         atexit.register(server_process.kill)
         sleep(5)
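Two notes on the new parameters above. First, typer derives CLI flags from keyword arguments, so allow_parallel_calls surfaces as --allow-parallel-calls (with --no-allow-parallel-calls to disable it), which is why the agent forwards the hyphenated form to the spawned server. Second, greedy defaults to on and pins top_k=1 / top_p=0.0, i.e. deterministic decoding. A self-contained sketch of that plumbing (a hypothetical standalone script, not the actual agent):

from typing import Optional

import typer

def main(
    allow_parallel_calls: Optional[bool] = False,
    greedy: Optional[bool] = True,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
):
    if greedy:
        top_k = 1    # always pick the single most likely token
        top_p = 0.0  # nucleus sampling disabled
    print(f"allow_parallel_calls={allow_parallel_calls} top_k={top_k} top_p={top_p}")

if __name__ == "__main__":
    typer.run(main)

Running it as `python sketch.py --allow-parallel-calls --no-greedy --top-k 40` shows how the flags map back onto the keyword arguments.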

View file

@@ -320,8 +320,8 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
     )

 class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
-    def __init__(self, args: ChatHandlerArgs):
-        super().__init__(args, escapes_underscores=False, allow_parallel_calls=False)
+    def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
+        super().__init__(args, escapes_underscores=False, allow_parallel_calls=allow_parallel_calls)

         # Hackily import https://github.com/NousResearch/Hermes-Function-Calling
         path = str(Path(__file__).parent / "hermes_function_calling")
@@ -433,7 +433,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
         content = '\n'.join(text_content).strip()
         return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)

-def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls=False):
+def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls):
     return {
         "type": "object",
         "properties": {
@@ -474,7 +474,7 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls):
     }

 class BespokeToolsChatHandler(ChatHandler):
-    def __init__(self, args: ChatHandlerArgs):
+    def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
         super().__init__(args)
         # args.response_schema = args.response_schema or {}
@@ -496,7 +496,8 @@ class BespokeToolsChatHandler(ChatHandler):
                     }
                     for tool in self.args.tools
                 ]
-            }
+            },
+            allow_parallel_calls=allow_parallel_calls,
         ),
         '',
     )
@@ -523,7 +524,8 @@ class BespokeToolsChatHandler(ChatHandler):
                         }
                     },
                     "required": ["name", "arguments"]
-                }
+                },
+                allow_parallel_calls=allow_parallel_calls,
             )
         ),
     ])
@@ -589,7 +591,7 @@ def get_chat_handler(args: ChatHandlerArgs, allow_parallel_calls=False) -> ChatHandler:
     elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL:
         return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, escapes_underscores=True, allow_parallel_calls=allow_parallel_calls)
     elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE:
-        return BespokeToolsChatHandler(args)
+        return BespokeToolsChatHandler(args, allow_parallel_calls=allow_parallel_calls)
     elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
         return Hermes2ProToolsChatHandler(args)
     else:
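To make the effect of allow_parallel_calls concrete: in a schema-constrained handler like the bespoke one above, the flag plausibly decides whether the grammar admits one tool call or several. A hypothetical sketch of that idea (make_tool_calls_schema is illustrative; it is not the actual _make_bespoke_schema):

def make_tool_calls_schema(tool_call_schema: dict, allow_parallel_calls: bool) -> dict:
    # Wrap a single tool-call schema in an array; without parallel calls,
    # cap the array at one element so only a single call can be emitted.
    schema = {
        "type": "array",
        "items": tool_call_schema,
        "minItems": 1,
    }
    if not allow_parallel_calls:
        schema["maxItems"] = 1
    return schema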

View file

@@ -31,6 +31,7 @@ def main(
     # model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
     host: str = "localhost",
     port: int = 8080,
+    allow_parallel_calls: Optional[bool] = False,
     auth: Optional[str] = None,
     verbose: bool = False,
     context_length: Optional[int] = None,
@@ -61,14 +62,15 @@ def main(
     if verbose:
         sys.stderr.write(f"# Starting C++ server with model {model} on {server_host}:{server_port}\n")

-    server_process = subprocess.Popen([
+    cmd = [
         "./server", "-m", model,
         "--host", server_host, "--port", f'{server_port}',
         # TODO: pass these from JSON / BaseSettings?
         '-ctk', 'q4_0', '-ctv', 'f16',
         "-c", f"{context_length}",
         *([] if verbose else ["--log-disable"]),
-    ], stdout=sys.stderr)
+    ]
+    server_process = subprocess.Popen(cmd, stdout=sys.stderr)
     atexit.register(server_process.kill)

     endpoint = f"http://{server_host}:{server_port}/completions"
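The refactor above is behavior-preserving: the argv moves into a cmd list before subprocess.Popen, which makes the command easy to log or extend with pass-through flags. The overall pattern, generalized (illustrative only, not the actual code):

import atexit
import subprocess
import sys

def spawn_cpp_server(model: str, host: str, port: int) -> subprocess.Popen:
    cmd = ["./server", "-m", model, "--host", host, "--port", str(port)]
    proc = subprocess.Popen(cmd, stdout=sys.stderr)  # keep server output off stdout
    atexit.register(proc.kill)  # best-effort cleanup when this process exits
    return proc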
@@ -88,7 +90,10 @@ def main(
     else:
         response_schema = None

-    chat_handler = get_chat_handler(ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools))
+    chat_handler = get_chat_handler(
+        ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools),
+        allow_parallel_calls=allow_parallel_calls
+    )

     messages = chat_request.messages
     if chat_handler.output_format_prompt:
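With the flag threaded through to get_chat_handler, a request against the spawned server can carry tools and, when parallel calls are allowed, come back with multiple entries in choices[0].message.tool_calls. A client-side sketch, assuming the OpenAI-style /v1/chat/completions endpoint mentioned in the agent's commented-out default (the payload shape follows the OpenAI chat-completions convention; get_weather is a hypothetical tool):

import json
import urllib.request

payload = {
    "messages": [{"role": "user", "content": "Weather in Paris and Tokyo?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
}
req = urllib.request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.dumps(json.load(resp), indent=2))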