agent: support OpenAI: --endpoint https://api.openai.com --auth "Bearer $OPENAI_API_KEY"

This commit is contained in:
ochafik 2024-05-22 04:11:48 +01:00
parent a39e6e0758
commit 793f4ff3f5
3 changed files with 66 additions and 6 deletions

View file

@@ -30,6 +30,7 @@ def completion_with_tool_usage(
messages: List[Message], messages: List[Message],
auth: Optional[str], auth: Optional[str],
verbose: bool, verbose: bool,
assume_llama_cpp_server: bool = False,
**kwargs): **kwargs):
''' '''
Creates a chat completion using an OpenAI-compatible endpoint w/ JSON schema support Creates a chat completion using an OpenAI-compatible endpoint w/ JSON schema support
@@ -75,7 +76,7 @@ def completion_with_tool_usage(
request = ChatCompletionRequest( request = ChatCompletionRequest(
messages=messages, messages=messages,
response_format=response_format, response_format=response_format,
tools=tools_schemas, tools=tools_schemas if tools_schemas else None,
cache_prompt=True, cache_prompt=True,
**kwargs, **kwargs,
) )
@@ -86,10 +87,65 @@ def completion_with_tool_usage(
} }
if auth: if auth:
headers["Authorization"] = auth headers["Authorization"] = auth
def drop_nones(o):
if isinstance(o, BaseModel):
return drop_nones(o.model_dump())
if isinstance(o, list):
return [drop_nones(i) for i in o if i is not None]
if isinstance(o, dict):
return {
k: drop_nones(v)
for k, v in o.items()
if v is not None
}
return o
if assume_llama_cpp_server:
body = request.model_dump()
else:
# request_dict = request.model_dump()
# body = drop_nones(request)
tools_arg = None
tool_choice = request.tool_choice
response_format = None
if request.tools:
tools_arg = drop_nones(request.tools)
if request.response_format:
response_format = {
'type': request.response_format.type,
}
if request.response_format.schema:
assert tools_arg is None
assert tool_choice is None
tools_arg = [{
"type": "function",
"function": {
"name": "output",
"description": "A JSON object",
"parameters": request.response_format.schema,
}
}]
tool_choice = "output"
body = drop_nones(dict(
messages=drop_nones(request.messages),
model=request.model,
tools=tools_arg,
tool_choice=tool_choice,
temperature=request.temperature,
response_format=response_format,
))
if verbose:
sys.stderr.write(f'# POSTing to {endpoint}/v1/chat/completions\n')
sys.stderr.write(f'# HEADERS: {headers}\n')
sys.stderr.write(f'# BODY: {json.dumps(body, indent=2)}\n')
response = requests.post( response = requests.post(
f'{endpoint}/v1/chat/completions', f'{endpoint}/v1/chat/completions',
headers=headers, headers=headers,
json=request.model_dump(), json=body,
) )
response.raise_for_status() response.raise_for_status()
response_json = response.json() response_json = response.json()
@@ -143,6 +199,7 @@ def main(
parallel_calls: Optional[bool] = False, parallel_calls: Optional[bool] = False,
verbose: bool = False, verbose: bool = False,
style: Optional[ToolsPromptStyle] = None, style: Optional[ToolsPromptStyle] = None,
assume_llama_cpp_server: Optional[bool] = None,
model: Optional[Annotated[str, typer.Option("--model", "-m")]] = None,# = "models/7B/ggml-model-f16.gguf", model: Optional[Annotated[str, typer.Option("--model", "-m")]] = None,# = "models/7B/ggml-model-f16.gguf",
model_url: Optional[Annotated[str, typer.Option("--model-url", "-mu")]] = None, model_url: Optional[Annotated[str, typer.Option("--model-url", "-mu")]] = None,
@@ -184,6 +241,7 @@ def main(
if not endpoint: if not endpoint:
server_port = 8080 server_port = 8080
server_host = 'localhost' server_host = 'localhost'
assume_llama_cpp_server = True
endpoint = f'http://{server_host}:{server_port}' endpoint = f'http://{server_host}:{server_port}'
if verbose: if verbose:
sys.stderr.write(f"# Starting C++ server with model {model} on {endpoint}\n") sys.stderr.write(f"# Starting C++ server with model {model} on {endpoint}\n")
@@ -231,13 +289,14 @@ def main(
result = completion_with_tool_usage( result = completion_with_tool_usage(
model="...", model="gpt-4o",
endpoint=endpoint, endpoint=endpoint,
response_model=response_model, response_model=response_model,
max_iterations=max_iterations, max_iterations=max_iterations,
tools=tool_functions, tools=tool_functions,
auth=auth, auth=auth,
verbose=verbose, verbose=verbose,
assume_llama_cpp_server=assume_llama_cpp_server or False,
n_predict=n_predict, n_predict=n_predict,
top_k=top_k, top_k=top_k,

View file

@@ -21,8 +21,8 @@ class Message(BaseModel):
class ToolFunction(BaseModel): class ToolFunction(BaseModel):
name: str name: str
description: str
parameters: dict[str, Any] parameters: dict[str, Any]
description: Optional[str] = None
class Tool(BaseModel): class Tool(BaseModel):
type: str type: str
@@ -58,6 +58,7 @@ class LlamaCppParams(BaseModel):
class ChatCompletionRequest(LlamaCppParams): class ChatCompletionRequest(LlamaCppParams):
model: str model: str
tools: Optional[List[Tool]] = None tools: Optional[List[Tool]] = None
tool_choice: Optional[str] = None
messages: Optional[List[Message]] = None messages: Optional[List[Message]] = None
prompt: Optional[str] = None prompt: Optional[str] = None
response_format: Optional[ResponseFormat] = None response_format: Optional[ResponseFormat] = None
@@ -87,5 +88,5 @@ class ChatCompletionResponse(BaseModel):
model: str model: str
choices: List[Choice] choices: List[Choice]
usage: Usage usage: Usage
system_fingerprint: str system_fingerprint: Optional[str] = None
error: Optional[CompletionError] = None error: Optional[CompletionError] = None

View file

@@ -698,7 +698,7 @@ def _tools_typescript_signatures(tools: list[Tool]) -> str:
# _ts_converter.resolve_refs(tool.function.parameters, tool.function.name) # _ts_converter.resolve_refs(tool.function.parameters, tool.function.name)
return 'namespace functions {\n' + '\n'.join( return 'namespace functions {\n' + '\n'.join(
'// ' + tool.function.description.replace('\n', '\n// ') + '\n' + '' '// ' + (tool.function.description or '').replace('\n', '\n// ') + '\n' + ''
'type ' + tool.function.name + ' = (_: ' + _ts_converter.visit(tool.function.parameters) + ") => any;\n" 'type ' + tool.function.name + ' = (_: ' + _ts_converter.visit(tool.function.parameters) + ") => any;\n"
for tool in tools for tool in tools
) + '} // namespace functions' ) + '} // namespace functions'