agent: fix response_format

This commit is contained in:
ochafik 2024-04-09 02:14:08 +01:00
parent 6e52a9ce48
commit 701a66d80f
6 changed files with 81 additions and 17 deletions

View file

@ -138,19 +138,28 @@ If you'd like to debug each binary separately (rather than have an agent spawning
```bash ```bash
# C++ server # C++ server
make -j server make -j server
./server --model mixtral.gguf --port 8081 ./server \
--model mixtral.gguf \
--metrics \
-ctk q4_0 \
-ctv f16 \
-c 32768 \
--port 8081
# OpenAI compatibility layer # OpenAI compatibility layer
python -m examples.openai \ python -m examples.openai \
--port 8080 --port 8080 \
--endpoint http://localhost:8081 \ --endpoint http://localhost:8081 \
--template_hf_model_id_fallback mistralai/Mixtral-8x7B-Instruct-v0.1 --template-hf-model-id-fallback mistralai/Mixtral-8x7B-Instruct-v0.1
# Or have the OpenAI compatibility layer spawn the C++ server under the hood: # Or have the OpenAI compatibility layer spawn the C++ server under the hood:
# python -m examples.openai --model mixtral.gguf # python -m examples.openai --model mixtral.gguf
# Agent itself: # Agent itself:
python -m examples.agent --endpoint http://localhost:8080 \ python -m examples.agent --endpoint http://localhost:8080 \
--tools examples/agent/tools/example_summaries.py \
--format PyramidalSummary \
--goal "Create a pyramidal summary of Mankind's recent advancements"
``` ```
## Use existing tools (WIP) ## Use existing tools (WIP)

View file

@ -10,7 +10,7 @@ import json, requests
from examples.json_schema_to_grammar import SchemaConverter from examples.json_schema_to_grammar import SchemaConverter
from examples.agent.tools.std_tools import StandardTools from examples.agent.tools.std_tools import StandardTools
from examples.openai.api import ChatCompletionRequest, ChatCompletionResponse, Message, Tool, ToolFunction from examples.openai.api import ChatCompletionRequest, ChatCompletionResponse, Message, ResponseFormat, Tool, ToolFunction
from examples.agent.utils import collect_functions, load_module from examples.agent.utils import collect_functions, load_module
from examples.openai.prompting import ToolsPromptStyle from examples.openai.prompting import ToolsPromptStyle
@ -46,7 +46,7 @@ def completion_with_tool_usage(
else: else:
type_adapter = TypeAdapter(response_model) type_adapter = TypeAdapter(response_model)
schema = type_adapter.json_schema() schema = type_adapter.json_schema()
response_format={"type": "json_object", "schema": schema } response_format=ResponseFormat(type="json_object", schema=schema)
tool_map = {fn.__name__: fn for fn in tools} tool_map = {fn.__name__: fn for fn in tools}
tools_schemas = [ tools_schemas = [
@ -77,14 +77,15 @@ def completion_with_tool_usage(
if auth: if auth:
headers["Authorization"] = auth headers["Authorization"] = auth
response = requests.post( response = requests.post(
endpoint, f'{endpoint}/v1/chat/completions',
headers=headers, headers=headers,
json=request.model_dump(), json=request.model_dump(),
) )
if response.status_code != 200: if response.status_code != 200:
raise Exception(f"Request failed ({response.status_code}): {response.text}") raise Exception(f"Request failed ({response.status_code}): {response.text}")
response = ChatCompletionResponse(**response.json()) response_json = response.json()
response = ChatCompletionResponse(**response_json)
if verbose: if verbose:
sys.stderr.write(f'# RESPONSE: {response.model_dump_json(indent=2)}\n') sys.stderr.write(f'# RESPONSE: {response.model_dump_json(indent=2)}\n')
if response.error: if response.error:
@ -169,7 +170,7 @@ def main(
if not endpoint: if not endpoint:
server_port = 8080 server_port = 8080
server_host = 'localhost' server_host = 'localhost'
endpoint: str = f'http://{server_host}:{server_port}/v1/chat/completions' endpoint = f'http://{server_host}:{server_port}'
if verbose: if verbose:
sys.stderr.write(f"# Starting C++ server with model {model} on {endpoint}\n") sys.stderr.write(f"# Starting C++ server with model {model} on {endpoint}\n")
cmd = [ cmd = [

View file

@ -28,8 +28,8 @@ class Tool(BaseModel):
function: ToolFunction function: ToolFunction
class ResponseFormat(BaseModel): class ResponseFormat(BaseModel):
type: str type: Literal["json_object"]
json_schema: Optional[Any] = None schema: Optional[Dict] = None
class LlamaCppParams(BaseModel): class LlamaCppParams(BaseModel):
n_predict: Optional[int] = None n_predict: Optional[int] = None

View file

@ -712,15 +712,19 @@ def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool, tool_style: Op
else: else:
raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}") raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}")
_ts_converter = SchemaToTypeScriptConverter()
# os.environ.get('NO_TS') # os.environ.get('NO_TS')
def _please_respond_with_schema(schema: dict) -> str: def _please_respond_with_schema(schema: dict) -> str:
# sig = json.dumps(schema, indent=2) # sig = json.dumps(schema, indent=2)
_ts_converter = SchemaToTypeScriptConverter()
_ts_converter.resolve_refs(schema, 'schema')
sig = _ts_converter.visit(schema) sig = _ts_converter.visit(schema)
return f'Please respond in JSON format with the following schema: {sig}' return f'Please respond in JSON format with the following schema: {sig}'
def _tools_typescript_signatures(tools: list[Tool]) -> str: def _tools_typescript_signatures(tools: list[Tool]) -> str:
_ts_converter = SchemaToTypeScriptConverter()
for tool in tools:
_ts_converter.resolve_refs(tool.function.parameters, tool.function.name)
return 'namespace functions {\n' + '\n'.join( return 'namespace functions {\n' + '\n'.join(
'// ' + tool.function.description.replace('\n', '\n// ') + '\n' + '' '// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
'type ' + tool.function.name + ' = (_: ' + _ts_converter.visit(tool.function.parameters) + ") => any;\n" 'type ' + tool.function.name + ' = (_: ' + _ts_converter.visit(tool.function.parameters) + ") => any;\n"

View file

@ -73,7 +73,7 @@ def main(
] ]
server_process = subprocess.Popen(cmd, stdout=sys.stderr) server_process = subprocess.Popen(cmd, stdout=sys.stderr)
atexit.register(server_process.kill) atexit.register(server_process.kill)
endpoint = f"http://{server_host}:{server_port}/completions" endpoint = f"http://{server_host}:{server_port}"
# print(chat_template.render([ # print(chat_template.render([
@ -125,7 +125,7 @@ def main(
if chat_request.response_format is not None: if chat_request.response_format is not None:
assert chat_request.response_format.type == "json_object", f"Unsupported response format: {chat_request.response_format.type}" assert chat_request.response_format.type == "json_object", f"Unsupported response format: {chat_request.response_format.type}"
response_schema = chat_request.response_format.json_schema or {} response_schema = chat_request.response_format.schema or {}
else: else:
response_schema = None response_schema = None
@ -164,7 +164,7 @@ def main(
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{endpoint}", f'{endpoint}/completions',
json=data, json=data,
headers=headers, headers=headers,
timeout=None) timeout=None)

View file

@ -14,6 +14,56 @@ class SchemaToTypeScriptConverter:
# // where to get weather. # // where to get weather.
# location: string, # location: string,
# }) => any; # }) => any;
def __init__(self):
self._refs = {}
self._refs_being_resolved = set()
def resolve_refs(self, schema: dict, url: str):
    '''
    Resolves all $ref fields in the given schema, fetching any remote schemas,
    replacing $ref with absolute reference URL and populating self._refs with the
    respective referenced (sub)schema dictionaries.

    Mutates `schema` in place (local '#/...' refs are rewritten to
    '{url}#/...') and returns the visited schema. `url` is used as the
    base for absolutizing local refs.
    '''
    def visit(n: dict):
        # Recursive walker: lists are visited element-wise, dicts are
        # inspected for a '$ref' key, anything else is returned untouched.
        if isinstance(n, list):
            return [visit(x) for x in n]
        elif isinstance(n, dict):
            ref = n.get('$ref')
            # Only resolve refs not already cached in self._refs.
            # NOTE(review): the cache check uses the ref as written (possibly
            # relative), while the cache is populated under the absolutized
            # form below — local refs may be re-resolved; confirm intended.
            if ref is not None and ref not in self._refs:
                if ref.startswith('https://'):
                    # Remote schema: fetching is gated behind _allow_fetch.
                    assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
                    import requests

                    # Split 'https://host/doc#/fragment' into document URL
                    # and fragment pointer.
                    frag_split = ref.split('#')
                    base_url = frag_split[0]

                    target = self._refs.get(base_url)
                    if target is None:
                        # Fetch the remote document once and resolve its own
                        # refs recursively; cache under the document URL.
                        target = self.resolve_refs(requests.get(ref).json(), base_url)
                        self._refs[base_url] = target

                    # No fragment (or empty fragment): the whole document is
                    # the target; nothing further to select.
                    if len(frag_split) == 1 or frag_split[-1] == '':
                        return target
                elif ref.startswith('#/'):
                    # Local ref: resolve against the current document and
                    # rewrite the ref in place to its absolute form.
                    target = schema
                    ref = f'{url}{ref}'
                    n['$ref'] = ref
                else:
                    raise ValueError(f'Unsupported ref {ref}')

                # Walk the fragment's JSON-pointer path (e.g. '/$defs/Foo')
                # down to the referenced (sub)schema.
                for sel in ref.split('#')[-1].split('/')[1:]:
                    assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
                    target = target[sel]

                # Cache the resolved target under the absolute ref URL.
                self._refs[ref] = target
            else:
                # No (new) ref here: recurse into all values.
                for v in n.values():
                    visit(v)

        return n
    return visit(schema)
def _desc_comment(self, schema: dict): def _desc_comment(self, schema: dict):
desc = schema.get("description", "").replace("\n", "\n// ") if 'description' in schema else None desc = schema.get("description", "").replace("\n", "\n// ") if 'description' in schema else None
return f'// {desc}\n' if desc else '' return f'// {desc}\n' if desc else ''
@ -78,7 +128,7 @@ class SchemaToTypeScriptConverter:
else: else:
add_component(t, is_required=True) add_component(t, is_required=True)
return self._build_object_rule(properties, required, additional_properties=[]) return self._build_object_rule(properties, required, additional_properties={})
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema): elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
items = schema.get('items') or schema['prefixItems'] items = schema.get('items') or schema['prefixItems']
@ -94,4 +144,4 @@ class SchemaToTypeScriptConverter:
return 'any' return 'any'
else: else:
return 'number' if schema_type == 'integer' else schema_type return 'number' if schema_type == 'integer' else schema_type or 'any'