agent: add --allow_parallel_calls

This commit is contained in:
parent c340e8cd3b
commit ce2fb0155f

4 changed files with 29 additions and 12 deletions
@@ -157,6 +157,8 @@ REQUIREMENTS_FILE=<( cat examples/agents/hermes_function_calling/requirements.tx
 ## TODO
 
+- Wait for spawned servers to be healthy
+
 - Add model URL / HF loading support
 
 - Add Embedding endpoint + storage / retrieval tools (Faiss? ScaNN?), or spontaneous RAG
@@ -128,6 +128,7 @@ def main(
     max_iterations: Optional[int] = 10,
     std_tools: Optional[bool] = False,
     auth: Optional[str] = None,
+    allow_parallel_calls: Optional[bool] = False,
     verbose: bool = False,
 
     model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
@@ -135,6 +136,8 @@ def main(
     context_length: Optional[int] = None,
     # endpoint: str = 'http://localhost:8080/v1/chat/completions',
 
+    greedy: Optional[bool] = True,
+
     n_predict: Optional[int] = 1000,
     top_k: Optional[int] = None,
     top_p: Optional[float] = None,
@@ -157,6 +160,10 @@ def main(
     n_probs: Optional[int] = None,
     min_keep: Optional[int] = None,
 ):
+    if greedy:
+        top_k = 1
+        top_p = 0.0
+
     if not endpoint:
         server_port = 8080
         server_host = 'localhost'
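The new `greedy` default above pins `top_k = 1` and `top_p = 0.0` before any request goes out; with only the single most likely token surviving the filter, sampling degenerates to argmax decoding. A standalone sketch of why (illustration only, not code from this commit):

```python
# Illustration only: with top_k=1 the sampler can only return the argmax.
import random

def sample(probs: dict[str, float], top_k: int | None = None) -> str:
    items = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    if top_k is not None:
        items = items[:top_k]  # top_k=1 keeps just the most likely token
    tokens, weights = zip(*items)
    return random.choices(tokens, weights=weights)[0]

probs = {"yes": 0.5, "no": 0.4, "maybe": 0.1}
assert sample(probs, top_k=1) == "yes"  # deterministic, i.e. greedy
```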
@@ -167,9 +174,10 @@ def main(
             "python", "-m", "examples.openai.server",
             "--model", model,
             *(['--verbose'] if verbose else []),
-            *([f'--context_length={context_length}'] if context_length else []),
+            *(['--allow-parallel-calls'] if allow_parallel_calls else []),
+            *([f'--context-length={context_length}'] if context_length else []),
+            *([])
         ]
-        print(cmd)
         server_process = subprocess.Popen(cmd, stdout=sys.stderr)
         atexit.register(server_process.kill)
         sleep(5)
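The command above is assembled with the list-splat idiom used throughout these files: each optional flag contributes either a one-element list or an empty list, and `*` flattens them into the final argv, so no `None` values ever reach `subprocess.Popen`. A small standalone sketch of the pattern (flag spellings taken from the hunk, the helper name is illustrative):

```python
def build_cmd(verbose: bool = False, context_length: int | None = None,
              allow_parallel_calls: bool = False) -> list[str]:
    # Each conditional expands to [] when the option is absent,
    # so the list contains only the flags actually requested.
    return [
        "python", "-m", "examples.openai.server",
        *(["--verbose"] if verbose else []),
        *(["--allow-parallel-calls"] if allow_parallel_calls else []),
        *([f"--context-length={context_length}"] if context_length else []),
    ]

print(build_cmd(verbose=True, context_length=4096))
# ['python', '-m', 'examples.openai.server', '--verbose', '--context-length=4096']
```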
@@ -320,8 +320,8 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
         )
 
 class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
-    def __init__(self, args: ChatHandlerArgs):
-        super().__init__(args, escapes_underscores=False, allow_parallel_calls=False)
+    def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
+        super().__init__(args, escapes_underscores=False, allow_parallel_calls=allow_parallel_calls)
 
         # Hackily import https://github.com/NousResearch/Hermes-Function-Calling
         path = str(Path(__file__).parent / "hermes_function_calling")
@@ -433,7 +433,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
         content = '\n'.join(text_content).strip()
         return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
 
-def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls=False):
+def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls):
     return {
         "type": "object",
         "properties": {
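`_make_bespoke_schema` now takes `allow_parallel_calls` explicitly instead of defaulting it to `False`, and the call sites below pass the flag through. Judging from those call sites, the flag plausibly decides whether the constrained output accepts one tool call or an array of them; a hedged sketch of that idea (the property names and exact shape here are assumptions, not the function's real body):

```python
# Hypothetical sketch: the real _make_bespoke_schema builds a richer schema.
def make_bespoke_schema_sketch(response_schema: dict, tool_call_schema: dict,
                               allow_parallel_calls: bool) -> dict:
    # Assumption: "parallel calls" means the model may emit an array of
    # tool calls in one turn; otherwise exactly one tool-call object.
    tool_calls = (
        {"type": "array", "items": tool_call_schema, "minItems": 1}
        if allow_parallel_calls
        else tool_call_schema
    )
    return {
        "type": "object",
        "properties": {
            "tool_calls": tool_calls,
            "response": response_schema,
        },
    }

print(make_bespoke_schema_sketch({"type": "string"}, {"type": "object"}, True))
```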
@@ -474,7 +474,7 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls
     }
 
 class BespokeToolsChatHandler(ChatHandler):
-    def __init__(self, args: ChatHandlerArgs):
+    def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
         super().__init__(args)
 
         # args.response_schema = args.response_schema or {}
@@ -496,7 +496,8 @@ class BespokeToolsChatHandler(ChatHandler):
                     }
                     for tool in self.args.tools
                 ]
-            }
+            },
+            allow_parallel_calls=allow_parallel_calls,
         ),
         '',
     )
@@ -523,7 +524,8 @@ class BespokeToolsChatHandler(ChatHandler):
                 }
             },
             "required": ["name", "arguments"]
-        }
+        },
+        allow_parallel_calls=allow_parallel_calls,
     )
 ),
 ])
@@ -589,7 +591,7 @@ def get_chat_handler(args: ChatHandlerArgs, allow_parallel_calls=False) -> ChatH
     elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL:
         return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, escapes_underscores=True, allow_parallel_calls=allow_parallel_calls)
     elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE:
-        return BespokeToolsChatHandler(args)
+        return BespokeToolsChatHandler(args, allow_parallel_calls=allow_parallel_calls)
     elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
-        return Hermes2ProToolsChatHandler(args)
+        return Hermes2ProToolsChatHandler(args, allow_parallel_calls=allow_parallel_calls)
     else:
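With the flag threaded through, `get_chat_handler` stays the single place that picks a handler from the template's tool style and hands every handler the same parallelism setting. A simplified, self-contained model of that factory pattern (class and field names shortened; the real handlers take `ChatHandlerArgs` and extra template arguments):

```python
# Simplified model of the dispatch in get_chat_handler.
from dataclasses import dataclass

@dataclass
class Args:
    tool_style: str  # stands in for args.chat_template.tool_style

class Handler:
    def __init__(self, args: Args, allow_parallel_calls: bool):
        self.args = args
        self.allow_parallel_calls = allow_parallel_calls

class Bespoke(Handler): pass
class Hermes2Pro(Handler): pass

def get_handler(args: Args, allow_parallel_calls: bool = False) -> Handler:
    dispatch = {"bespoke": Bespoke, "hermes_2_pro": Hermes2Pro}
    # Threading the flag through the factory keeps every handler consistent.
    return dispatch[args.tool_style](args, allow_parallel_calls=allow_parallel_calls)

assert get_handler(Args("bespoke"), allow_parallel_calls=True).allow_parallel_calls
```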
@@ -31,6 +31,7 @@ def main(
     # model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
     host: str = "localhost",
     port: int = 8080,
+    allow_parallel_calls: Optional[bool] = False,
     auth: Optional[str] = None,
     verbose: bool = False,
     context_length: Optional[int] = None,
@@ -61,14 +62,15 @@ def main(
 
         if verbose:
             sys.stderr.write(f"# Starting C++ server with model {model} on {server_host}:{server_port}\n")
-        server_process = subprocess.Popen([
+        cmd = [
             "./server", "-m", model,
             "--host", server_host, "--port", f'{server_port}',
             # TODO: pass these from JSON / BaseSettings?
             '-ctk', 'q4_0', '-ctv', 'f16',
             "-c", f"{context_length}",
             *([] if verbose else ["--log-disable"]),
-        ], stdout=sys.stderr)
+        ]
+        server_process = subprocess.Popen(cmd, stdout=sys.stderr)
         atexit.register(server_process.kill)
         endpoint = f"http://{server_host}:{server_port}/completions"
 
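After spawning, the agent still just sleeps five seconds, and the README TODO above tracks "wait for spawned servers to be healthy". One way that TODO could be addressed is to poll the endpoint until it answers; a sketch under the assumption that any HTTP response on the completions URL means the server is up (the URL matches the code above, the retry logic is illustrative):

```python
import time
import urllib.error
import urllib.request

def wait_until_healthy(url: str, timeout: float = 30.0, interval: float = 0.5) -> None:
    """Poll `url` until any HTTP response arrives or `timeout` elapses."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            urllib.request.urlopen(url, timeout=interval)
            return  # server answered: consider it healthy
        except urllib.error.HTTPError:
            return  # an HTTP error still means the server is listening
        except (urllib.error.URLError, OSError):
            time.sleep(interval)  # not listening yet; retry
    raise TimeoutError(f"server at {url} not healthy after {timeout}s")

# e.g. wait_until_healthy("http://localhost:8080/completions")
```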
@@ -88,7 +90,10 @@ def main(
         else:
             response_schema = None
 
-        chat_handler = get_chat_handler(ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools))
+        chat_handler = get_chat_handler(
+            ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools),
+            allow_parallel_calls=allow_parallel_calls
+        )
 
         messages = chat_request.messages
         if chat_handler.output_format_prompt:
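End to end, the flag now travels from the CLI into the schema that constrains generation, so a single chat turn may carry several tool calls. A hedged example of what a client request against this server could look like (the URL matches the commented-out endpoint in the agent above; the payload follows the OpenAI chat-completions convention, and the tool is illustrative):

```python
import json
import urllib.request

payload = {
    "model": "ggml-model-f16.gguf",
    "messages": [{"role": "user", "content": "What's the weather in Paris and Tokyo?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
}
req = urllib.request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
# With --allow_parallel_calls the assistant message may carry two
# tool_calls (Paris and Tokyo) instead of one per round-trip.
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["choices"][0]["message"])
```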