agent: add --allow_parallel_calls
This commit is contained in:
parent
c340e8cd3b
commit
ce2fb0155f
4 changed files with 29 additions and 12 deletions
|
@ -157,6 +157,8 @@ REQUIREMENTS_FILE=<( cat examples/agents/hermes_function_calling/requirements.tx
|
|||
|
||||
## TODO
|
||||
|
||||
- Wait for spawned servers to be healthy
|
||||
|
||||
- Add model URL / HF loading support
|
||||
|
||||
- Add Embedding endpoint + storage / retrieval tools (Faiss? ScaNN?), or spontaneous RAG
|
||||
|
|
|
@ -128,6 +128,7 @@ def main(
|
|||
max_iterations: Optional[int] = 10,
|
||||
std_tools: Optional[bool] = False,
|
||||
auth: Optional[str] = None,
|
||||
allow_parallel_calls: Optional[bool] = False,
|
||||
verbose: bool = False,
|
||||
|
||||
model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
|
||||
|
@ -135,6 +136,8 @@ def main(
|
|||
context_length: Optional[int] = None,
|
||||
# endpoint: str = 'http://localhost:8080/v1/chat/completions',
|
||||
|
||||
greedy: Optional[bool] = True,
|
||||
|
||||
n_predict: Optional[int] = 1000,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
|
@ -157,6 +160,10 @@ def main(
|
|||
n_probs: Optional[int] = None,
|
||||
min_keep: Optional[int] = None,
|
||||
):
|
||||
if greedy:
|
||||
top_k = 1
|
||||
top_p = 0.0
|
||||
|
||||
if not endpoint:
|
||||
server_port = 8080
|
||||
server_host = 'localhost'
|
||||
|
@ -167,9 +174,10 @@ def main(
|
|||
"python", "-m", "examples.openai.server",
|
||||
"--model", model,
|
||||
*(['--verbose'] if verbose else []),
|
||||
*([f'--context_length={context_length}'] if context_length else []),
|
||||
*(['--allow-parallel-calls'] if allow_parallel_calls else []),
|
||||
*(['--context-length={context_length}'] if context_length else []),
|
||||
*([])
|
||||
]
|
||||
print(cmd)
|
||||
server_process = subprocess.Popen(cmd, stdout=sys.stderr)
|
||||
atexit.register(server_process.kill)
|
||||
sleep(5)
|
||||
|
|
|
@ -320,8 +320,8 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
|
|||
)
|
||||
|
||||
class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
|
||||
def __init__(self, args: ChatHandlerArgs):
|
||||
super().__init__(args, escapes_underscores=False, allow_parallel_calls=False)
|
||||
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
|
||||
super().__init__(args, escapes_underscores=False, allow_parallel_calls=allow_parallel_calls)
|
||||
|
||||
# Hackily import https://github.com/NousResearch/Hermes-Function-Calling
|
||||
path = str(Path(__file__).parent / "hermes_function_calling")
|
||||
|
@ -433,7 +433,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
|
|||
content = '\n'.join(text_content).strip()
|
||||
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
|
||||
|
||||
def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls=False):
|
||||
def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls):
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -474,7 +474,7 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls
|
|||
}
|
||||
|
||||
class BespokeToolsChatHandler(ChatHandler):
|
||||
def __init__(self, args: ChatHandlerArgs):
|
||||
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
|
||||
super().__init__(args)
|
||||
|
||||
# args.response_schema = args.response_schema or {}
|
||||
|
@ -496,7 +496,8 @@ class BespokeToolsChatHandler(ChatHandler):
|
|||
}
|
||||
for tool in self.args.tools
|
||||
]
|
||||
}
|
||||
},
|
||||
allow_parallel_calls=allow_parallel_calls,
|
||||
),
|
||||
'',
|
||||
)
|
||||
|
@ -523,7 +524,8 @@ class BespokeToolsChatHandler(ChatHandler):
|
|||
}
|
||||
},
|
||||
"required": ["name", "arguments"]
|
||||
}
|
||||
},
|
||||
allow_parallel_calls=allow_parallel_calls,
|
||||
)
|
||||
),
|
||||
])
|
||||
|
@ -589,7 +591,7 @@ def get_chat_handler(args: ChatHandlerArgs, allow_parallel_calls=False) -> ChatH
|
|||
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL:
|
||||
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, escapes_underscores=True, allow_parallel_calls=allow_parallel_calls)
|
||||
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE:
|
||||
return BespokeToolsChatHandler(args)
|
||||
return BespokeToolsChatHandler(args, allow_parallel_calls=allow_parallel_calls)
|
||||
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
|
||||
return Hermes2ProToolsChatHandler(args)
|
||||
else:
|
||||
|
|
|
@ -31,6 +31,7 @@ def main(
|
|||
# model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
|
||||
host: str = "localhost",
|
||||
port: int = 8080,
|
||||
allow_parallel_calls: Optional[bool] = False,
|
||||
auth: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
context_length: Optional[int] = None,
|
||||
|
@ -61,14 +62,15 @@ def main(
|
|||
|
||||
if verbose:
|
||||
sys.stderr.write(f"# Starting C++ server with model {model} on {server_host}:{server_port}\n")
|
||||
server_process = subprocess.Popen([
|
||||
cmd = [
|
||||
"./server", "-m", model,
|
||||
"--host", server_host, "--port", f'{server_port}',
|
||||
# TODO: pass these from JSON / BaseSettings?
|
||||
'-ctk', 'q4_0', '-ctv', 'f16',
|
||||
"-c", f"{context_length}",
|
||||
*([] if verbose else ["--log-disable"]),
|
||||
], stdout=sys.stderr)
|
||||
]
|
||||
server_process = subprocess.Popen(cmd, stdout=sys.stderr)
|
||||
atexit.register(server_process.kill)
|
||||
endpoint = f"http://{server_host}:{server_port}/completions"
|
||||
|
||||
|
@ -88,7 +90,10 @@ def main(
|
|||
else:
|
||||
response_schema = None
|
||||
|
||||
chat_handler = get_chat_handler(ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools))
|
||||
chat_handler = get_chat_handler(
|
||||
ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools),
|
||||
allow_parallel_calls=allow_parallel_calls
|
||||
)
|
||||
|
||||
messages = chat_request.messages
|
||||
if chat_handler.output_format_prompt:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue