agent: support OpenAI: --endpoint https://api.openai.com --auth "Bearer $OPENAI_API_KEY"
commit 793f4ff3f5 (parent a39e6e0758)
3 changed files with 66 additions and 6 deletions
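
In short: the agent can now talk to any OpenAI-compatible endpoint, with the --auth value passed through verbatim as the Authorization header. A minimal sketch of the resulting request, assuming only the requests library (the model name and message below are illustrative):

    import os
    import requests

    endpoint = "https://api.openai.com"              # from the commit title
    auth = f"Bearer {os.environ['OPENAI_API_KEY']}"  # from the commit title

    headers = {"Content-Type": "application/json"}
    if auth:
        headers["Authorization"] = auth  # passed through verbatim, as in the diff

    response = requests.post(
        f"{endpoint}/v1/chat/completions",
        headers=headers,
        json={"model": "gpt-4o", "messages": [{"role": "user", "content": "Hi"}]},
    )
    response.raise_for_status()
    print(response.json()["choices"][0]["message"]["content"])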
@@ -30,6 +30,7 @@ def completion_with_tool_usage(
     messages: List[Message],
     auth: Optional[str],
     verbose: bool,
+    assume_llama_cpp_server: bool = False,
     **kwargs):
     '''
     Creates a chat completion using an OpenAI-compatible endpoint w/ JSON schema support
@@ -75,7 +76,7 @@ def completion_with_tool_usage(
     request = ChatCompletionRequest(
         messages=messages,
         response_format=response_format,
-        tools=tools_schemas,
+        tools=tools_schemas if tools_schemas else None,
         cache_prompt=True,
         **kwargs,
     )
@@ -86,10 +87,65 @@ def completion_with_tool_usage(
     }
     if auth:
         headers["Authorization"] = auth

+    def drop_nones(o):
+        if isinstance(o, BaseModel):
+            return drop_nones(o.model_dump())
+        if isinstance(o, list):
+            return [drop_nones(i) for i in o if i is not None]
+        if isinstance(o, dict):
+            return {
+                k: drop_nones(v)
+                for k, v in o.items()
+                if v is not None
+            }
+        return o
+
+    if assume_llama_cpp_server:
+        body = request.model_dump()
+    else:
+        # request_dict = request.model_dump()
+        # body = drop_nones(request)
+        tools_arg = None
+        tool_choice = request.tool_choice
+        response_format = None
+        if request.tools:
+            tools_arg = drop_nones(request.tools)
+        if request.response_format:
+            response_format = {
+                'type': request.response_format.type,
+            }
+            if request.response_format.schema:
+                assert tools_arg is None
+                assert tool_choice is None
+                tools_arg = [{
+                    "type": "function",
+                    "function": {
+                        "name": "output",
+                        "description": "A JSON object",
+                        "parameters": request.response_format.schema,
+                    }
+                }]
+                tool_choice = "output"
+
+        body = drop_nones(dict(
+            messages=drop_nones(request.messages),
+            model=request.model,
+            tools=tools_arg,
+            tool_choice=tool_choice,
+            temperature=request.temperature,
+            response_format=response_format,
+        ))
+
+    if verbose:
+        sys.stderr.write(f'# POSTing to {endpoint}/v1/chat/completions\n')
+        sys.stderr.write(f'# HEADERS: {headers}\n')
+        sys.stderr.write(f'# BODY: {json.dumps(body, indent=2)}\n')
+
     response = requests.post(
         f'{endpoint}/v1/chat/completions',
         headers=headers,
-        json=request.model_dump(),
+        json=body,
     )
     response.raise_for_status()
     response_json = response.json()
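
The drop_nones helper above is the crux of the OpenAI path: llama.cpp's server tolerates explicit null fields in the request body, but stricter OpenAI-compatible endpoints can reject them, so None values are stripped recursively before serialization. A standalone sketch of its behavior (the helper is copied from the hunk; the sample payload is illustrative):

    from pydantic import BaseModel

    def drop_nones(o):
        # Pydantic models are dumped to plain dicts first, then cleaned.
        if isinstance(o, BaseModel):
            return drop_nones(o.model_dump())
        if isinstance(o, list):
            return [drop_nones(i) for i in o if i is not None]
        if isinstance(o, dict):
            return {k: drop_nones(v) for k, v in o.items() if v is not None}
        return o

    print(drop_nones({"model": "gpt-4o", "tools": None,
                      "messages": [{"role": "user", "name": None}]}))
    # -> {'model': 'gpt-4o', 'messages': [{'role': 'user'}]}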
@@ -143,6 +199,7 @@ def main(
     parallel_calls: Optional[bool] = False,
     verbose: bool = False,
     style: Optional[ToolsPromptStyle] = None,
+    assume_llama_cpp_server: Optional[bool] = None,

     model: Optional[Annotated[str, typer.Option("--model", "-m")]] = None,# = "models/7B/ggml-model-f16.gguf",
     model_url: Optional[Annotated[str, typer.Option("--model-url", "-mu")]] = None,
@@ -184,6 +241,7 @@ def main(
     if not endpoint:
         server_port = 8080
         server_host = 'localhost'
+        assume_llama_cpp_server = True
         endpoint = f'http://{server_host}:{server_port}'
         if verbose:
             sys.stderr.write(f"# Starting C++ server with model {model} on {endpoint}\n")
@@ -231,13 +289,14 @@ def main(


     result = completion_with_tool_usage(
-        model="...",
+        model="gpt-4o",
         endpoint=endpoint,
         response_model=response_model,
         max_iterations=max_iterations,
         tools=tool_functions,
         auth=auth,
         verbose=verbose,
+        assume_llama_cpp_server=assume_llama_cpp_server or False,

         n_predict=n_predict,
         top_k=top_k,
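
When a JSON schema is requested against a plain OpenAI endpoint, the hunk at -86,10 repackages the schema as a single forced function call named "output". A hedged sketch of the body that path produces (the schema and messages are illustrative; note that the diff passes the bare tool name as tool_choice, whereas OpenAI's documented form is an object like {"type": "function", "function": {"name": "output"}}):

    import json

    schema = {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    }

    body = {
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": "Answer as JSON."}],
        "tools": [{
            "type": "function",
            "function": {
                "name": "output",            # fixed name chosen by the diff
                "description": "A JSON object",
                "parameters": schema,        # the requested response schema
            },
        }],
        "tool_choice": "output",             # bare name, as written in the diff
    }
    print(json.dumps(body, indent=2))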
@@ -21,8 +21,8 @@ class Message(BaseModel):

 class ToolFunction(BaseModel):
     name: str
-    description: str
     parameters: dict[str, Any]
+    description: Optional[str] = None

 class Tool(BaseModel):
     type: str
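
Making description optional with a None default lets tool definitions that omit it validate instead of raising a Pydantic ValidationError. A minimal sketch with the same field names as the diff (the sample tool is illustrative):

    from typing import Any, Optional
    from pydantic import BaseModel

    class ToolFunction(BaseModel):
        name: str
        parameters: dict[str, Any]
        description: Optional[str] = None

    # Previously this raised a ValidationError; now description defaults to None.
    fn = ToolFunction(name="output", parameters={"type": "object"})
    print(fn.description)  # None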
@@ -58,6 +58,7 @@ class LlamaCppParams(BaseModel):
 class ChatCompletionRequest(LlamaCppParams):
     model: str
     tools: Optional[List[Tool]] = None
+    tool_choice: Optional[str] = None
     messages: Optional[List[Message]] = None
     prompt: Optional[str] = None
     response_format: Optional[ResponseFormat] = None
@@ -87,5 +88,5 @@ class ChatCompletionResponse(BaseModel):
     model: str
     choices: List[Choice]
     usage: Usage
-    system_fingerprint: str
+    system_fingerprint: Optional[str] = None
     error: Optional[CompletionError] = None
@@ -698,7 +698,7 @@ def _tools_typescript_signatures(tools: list[Tool]) -> str:
     # _ts_converter.resolve_refs(tool.function.parameters, tool.function.name)

     return 'namespace functions {\n' + '\n'.join(
-        '// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
+        '// ' + (tool.function.description or '').replace('\n', '\n// ') + '\n' + ''
         'type ' + tool.function.name + ' = (_: ' + _ts_converter.visit(tool.function.parameters) + ") => any;\n"
         for tool in tools
     ) + '} // namespace functions'
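
The last hunk guards against the now-optional description: calling .replace on None would raise AttributeError while rendering TypeScript-style signatures for the prompt. A simplified standalone sketch of the output format (dict-based tools and json.dumps stand in for the module's Tool model and _ts_converter, which are not shown in the diff):

    import json

    def typescript_signatures(tools):
        # A missing description becomes an empty // comment, per the fix above.
        return 'namespace functions {\n' + '\n'.join(
            '// ' + (t["description"] or '').replace('\n', '\n// ') + '\n'
            'type ' + t["name"] + ' = (_: ' + json.dumps(t["parameters"]) + ') => any;\n'
            for t in tools
        ) + '} // namespace functions'

    print(typescript_signatures([{"name": "output", "description": None,
                                  "parameters": {"type": "object"}}]))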