agent: add --allow_parallel_calls

This commit is contained in:
ochafik 2024-03-29 16:40:23 +00:00
parent c340e8cd3b
commit ce2fb0155f
4 changed files with 29 additions and 12 deletions

View file

@@ -157,6 +157,8 @@ REQUIREMENTS_FILE=<( cat examples/agents/hermes_function_calling/requirements.tx
## TODO
- Wait for spawned servers to be healthy
- Add model URL / HF loading support
- Add Embedding endpoint + storage / retrieval tools (Faiss? ScaNN?), or spontaneous RAG
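
The last TODO item above calls for embedding storage/retrieval; here is a minimal, hypothetical sketch of what a Faiss-backed tool could look like. The `embed` helper and the 384-dim vectors are assumptions, not anything this commit ships:

```python
# Hypothetical sketch of the Faiss-backed storage/retrieval tool from the
# TODO list; `embed` is a stand-in for the not-yet-existing embedding endpoint.
import faiss
import numpy as np

def embed(texts: list[str]) -> np.ndarray:
    # Placeholder: would call the server's embedding endpoint once it exists.
    return np.random.rand(len(texts), 384).astype('float32')

docs = ["first note", "second note"]
index = faiss.IndexFlatL2(384)   # exact L2 search over 384-dim vectors
index.add(embed(docs))

def retrieve(query: str, k: int = 2) -> list[str]:
    # search() returns (distances, indices); map indices back to documents.
    _, ids = index.search(embed([query]), k)
    return [docs[i] for i in ids[0]]
```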

View file

@@ -128,6 +128,7 @@ def main(
max_iterations: Optional[int] = 10,
std_tools: Optional[bool] = False,
auth: Optional[str] = None,
allow_parallel_calls: Optional[bool] = False,
verbose: bool = False,
model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
@@ -135,6 +136,8 @@ def main(
context_length: Optional[int] = None,
# endpoint: str = 'http://localhost:8080/v1/chat/completions',
greedy: Optional[bool] = True,
n_predict: Optional[int] = 1000,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
@@ -157,6 +160,10 @@ def main(
n_probs: Optional[int] = None,
min_keep: Optional[int] = None,
):
if greedy:
top_k = 1
top_p = 0.0
if not endpoint:
server_port = 8080
server_host = 'localhost'
@@ -167,9 +174,10 @@ def main(
"python", "-m", "examples.openai.server",
"--model", model,
*(['--verbose'] if verbose else []),
*([f'--context_length={context_length}'] if context_length else []),
*(['--allow-parallel-calls'] if allow_parallel_calls else []),
*([f'--context-length={context_length}'] if context_length else []),
*([])
]
print(cmd)
server_process = subprocess.Popen(cmd, stdout=sys.stderr)
atexit.register(server_process.kill)
sleep(5)
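
The `sleep(5)` above is the fixed wait that the README's first TODO item ("wait for spawned servers to be healthy") wants to replace. A possible polling sketch, assuming the spawned server exposes a `/health` route like llama.cpp's C++ server does (whether this Python wrapper forwards one is an assumption):

```python
# Sketch: poll a health route instead of sleeping a fixed 5 seconds.
import time
import urllib.request

def wait_until_healthy(url: str, timeout: float = 30.0) -> None:
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=1) as resp:
                if resp.status == 200:
                    return
        except OSError:
            pass  # server not accepting connections yet
        time.sleep(0.5)
    raise TimeoutError(f"server at {url} never became healthy")

# wait_until_healthy(f'http://{server_host}:{server_port}/health')
```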

View file

@@ -320,8 +320,8 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
)
class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
def __init__(self, args: ChatHandlerArgs):
super().__init__(args, escapes_underscores=False, allow_parallel_calls=False)
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
super().__init__(args, escapes_underscores=False, allow_parallel_calls=allow_parallel_calls)
# Hackily import https://github.com/NousResearch/Hermes-Function-Calling
path = str(Path(__file__).parent / "hermes_function_calling")
@@ -433,7 +433,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
content = '\n'.join(text_content).strip()
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls=False):
def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls):
return {
"type": "object",
"properties": {
@@ -474,7 +474,7 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls
}
class BespokeToolsChatHandler(ChatHandler):
def __init__(self, args: ChatHandlerArgs):
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
super().__init__(args)
# args.response_schema = args.response_schema or {}
@@ -496,7 +496,8 @@ class BespokeToolsChatHandler(ChatHandler):
}
for tool in self.args.tools
]
}
},
allow_parallel_calls=allow_parallel_calls,
),
'',
)
@@ -523,7 +524,8 @@ class BespokeToolsChatHandler(ChatHandler):
}
},
"required": ["name", "arguments"]
}
},
allow_parallel_calls=allow_parallel_calls,
)
),
])
@@ -589,7 +591,7 @@ def get_chat_handler(args: ChatHandlerArgs, allow_parallel_calls=False) -> ChatH
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL:
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, escapes_underscores=True, allow_parallel_calls=allow_parallel_calls)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE:
return BespokeToolsChatHandler(args)
return BespokeToolsChatHandler(args, allow_parallel_calls=allow_parallel_calls)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
return Hermes2ProToolsChatHandler(args)
return Hermes2ProToolsChatHandler(args, allow_parallel_calls=allow_parallel_calls)
else:

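The body of `_make_bespoke_schema` is mostly elided above, so how `allow_parallel_calls` actually changes the generated schema is not visible here. One plausible reading, sketched purely as an assumption, is that the flag switches the constrained output between a single tool call and an array of them:

```python
# Hypothetical illustration only; the commit's real _make_bespoke_schema body
# is not shown in this diff.
def make_tool_calls_schema(tool_call_schema: dict, allow_parallel_calls: bool) -> dict:
    if allow_parallel_calls:
        # Accept any number of tool calls in one assistant turn.
        return {"type": "array", "items": tool_call_schema, "minItems": 1}
    # Accept exactly one tool call.
    return tool_call_schema
```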
View file

@@ -31,6 +31,7 @@ def main(
# model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
host: str = "localhost",
port: int = 8080,
allow_parallel_calls: Optional[bool] = False,
auth: Optional[str] = None,
verbose: bool = False,
context_length: Optional[int] = None,
@@ -61,14 +62,15 @@ def main(
if verbose:
sys.stderr.write(f"# Starting C++ server with model {model} on {server_host}:{server_port}\n")
server_process = subprocess.Popen([
cmd = [
"./server", "-m", model,
"--host", server_host, "--port", f'{server_port}',
# TODO: pass these from JSON / BaseSettings?
'-ctk', 'q4_0', '-ctv', 'f16',
"-c", f"{context_length}",
*([] if verbose else ["--log-disable"]),
], stdout=sys.stderr)
]
server_process = subprocess.Popen(cmd, stdout=sys.stderr)
atexit.register(server_process.kill)
endpoint = f"http://{server_host}:{server_port}/completions"
@@ -88,7 +90,10 @@ def main(
else:
response_schema = None
chat_handler = get_chat_handler(ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools))
chat_handler = get_chat_handler(
ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools),
allow_parallel_calls=allow_parallel_calls
)
messages = chat_request.messages
if chat_handler.output_format_prompt:
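
For context, a hedged end-to-end usage sketch: typer renders the `allow_parallel_calls` option as `--allow-parallel-calls` on the command line, and once the server is up, a tools request may come back with several `tool_calls` in one assistant message. The endpoint path is taken from the commented-out default earlier in this diff; the tool definition is made up:

```python
# Usage sketch against the OpenAI-compatible endpoint; the get_weather tool
# is a made-up example, not part of this commit.
import json
import urllib.request

payload = {
    "messages": [{"role": "user", "content": "Weather in Paris and Tokyo?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
}
req = urllib.request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    message = json.load(resp)["choices"][0]["message"]
# One entry without the flag; possibly several with --allow-parallel-calls.
print(message.get("tool_calls"))
```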