agent: fix response_format

This commit is contained in:
ochafik 2024-04-09 02:14:08 +01:00
parent 6e52a9ce48
commit 701a66d80f
6 changed files with 81 additions and 17 deletions

View file

@ -138,19 +138,28 @@ If you'd like to debug each binary separately (rather than have an agent spawing
```bash
# C++ server
make -j server
./server --model mixtral.gguf --port 8081
./server \
--model mixtral.gguf \
--metrics \
-ctk q4_0 \
-ctv f16 \
-c 32768 \
--port 8081
# OpenAI compatibility layer
python -m examples.openai \
--port 8080
--port 8080 \
--endpoint http://localhost:8081 \
--template_hf_model_id_fallback mistralai/Mixtral-8x7B-Instruct-v0.1
--template-hf-model-id-fallback mistralai/Mixtral-8x7B-Instruct-v0.1
# Or have the OpenAI compatibility layer spawn the C++ server under the hood:
# python -m examples.openai --model mixtral.gguf
# Agent itself:
python -m examples.agent --endpoint http://localhost:8080 \
--tools examples/agent/tools/example_summaries.py \
--format PyramidalSummary \
--goal "Create a pyramidal summary of Mankind's recent advancements"
```
## Use existing tools (WIP)

View file

@ -10,7 +10,7 @@ import json, requests
from examples.json_schema_to_grammar import SchemaConverter
from examples.agent.tools.std_tools import StandardTools
from examples.openai.api import ChatCompletionRequest, ChatCompletionResponse, Message, Tool, ToolFunction
from examples.openai.api import ChatCompletionRequest, ChatCompletionResponse, Message, ResponseFormat, Tool, ToolFunction
from examples.agent.utils import collect_functions, load_module
from examples.openai.prompting import ToolsPromptStyle
@ -46,7 +46,7 @@ def completion_with_tool_usage(
else:
type_adapter = TypeAdapter(response_model)
schema = type_adapter.json_schema()
response_format={"type": "json_object", "schema": schema }
response_format=ResponseFormat(type="json_object", schema=schema)
tool_map = {fn.__name__: fn for fn in tools}
tools_schemas = [
@ -77,14 +77,15 @@ def completion_with_tool_usage(
if auth:
headers["Authorization"] = auth
response = requests.post(
endpoint,
f'{endpoint}/v1/chat/completions',
headers=headers,
json=request.model_dump(),
)
if response.status_code != 200:
raise Exception(f"Request failed ({response.status_code}): {response.text}")
response = ChatCompletionResponse(**response.json())
response_json = response.json()
response = ChatCompletionResponse(**response_json)
if verbose:
sys.stderr.write(f'# RESPONSE: {response.model_dump_json(indent=2)}\n')
if response.error:
@ -169,7 +170,7 @@ def main(
if not endpoint:
server_port = 8080
server_host = 'localhost'
endpoint: str = f'http://{server_host}:{server_port}/v1/chat/completions'
endpoint = f'http://{server_host}:{server_port}'
if verbose:
sys.stderr.write(f"# Starting C++ server with model {model} on {endpoint}\n")
cmd = [

View file

@ -28,8 +28,8 @@ class Tool(BaseModel):
function: ToolFunction
class ResponseFormat(BaseModel):
type: str
json_schema: Optional[Any] = None
type: Literal["json_object"]
schema: Optional[Dict] = None
class LlamaCppParams(BaseModel):
n_predict: Optional[int] = None

View file

@ -712,15 +712,19 @@ def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool, tool_style: Op
else:
raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}")
_ts_converter = SchemaToTypeScriptConverter()
# os.environ.get('NO_TS')
def _please_respond_with_schema(schema: dict) -> str:
# sig = json.dumps(schema, indent=2)
_ts_converter = SchemaToTypeScriptConverter()
_ts_converter.resolve_refs(schema, 'schema')
sig = _ts_converter.visit(schema)
return f'Please respond in JSON format with the following schema: {sig}'
def _tools_typescript_signatures(tools: list[Tool]) -> str:
_ts_converter = SchemaToTypeScriptConverter()
for tool in tools:
_ts_converter.resolve_refs(tool.function.parameters, tool.function.name)
return 'namespace functions {\n' + '\n'.join(
'// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
'type ' + tool.function.name + ' = (_: ' + _ts_converter.visit(tool.function.parameters) + ") => any;\n"

View file

@ -73,7 +73,7 @@ def main(
]
server_process = subprocess.Popen(cmd, stdout=sys.stderr)
atexit.register(server_process.kill)
endpoint = f"http://{server_host}:{server_port}/completions"
endpoint = f"http://{server_host}:{server_port}"
# print(chat_template.render([
@ -125,7 +125,7 @@ def main(
if chat_request.response_format is not None:
assert chat_request.response_format.type == "json_object", f"Unsupported response format: {chat_request.response_format.type}"
response_schema = chat_request.response_format.json_schema or {}
response_schema = chat_request.response_format.schema or {}
else:
response_schema = None
@ -164,7 +164,7 @@ def main(
async with httpx.AsyncClient() as client:
response = await client.post(
f"{endpoint}",
f'{endpoint}/completions',
json=data,
headers=headers,
timeout=None)

View file

@ -14,6 +14,56 @@ class SchemaToTypeScriptConverter:
# // where to get weather.
# location: string,
# }) => any;
def __init__(self):
self._refs = {}
self._refs_being_resolved = set()
def resolve_refs(self, schema: dict, url: str):
'''
Resolves all $ref fields in the given schema, fetching any remote schemas,
replacing $ref with absolute reference URL and populating self._refs with the
respective referenced (sub)schema dictionaries.
'''
def visit(n: dict):
if isinstance(n, list):
return [visit(x) for x in n]
elif isinstance(n, dict):
ref = n.get('$ref')
if ref is not None and ref not in self._refs:
if ref.startswith('https://'):
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
import requests
frag_split = ref.split('#')
base_url = frag_split[0]
target = self._refs.get(base_url)
if target is None:
target = self.resolve_refs(requests.get(ref).json(), base_url)
self._refs[base_url] = target
if len(frag_split) == 1 or frag_split[-1] == '':
return target
elif ref.startswith('#/'):
target = schema
ref = f'{url}{ref}'
n['$ref'] = ref
else:
raise ValueError(f'Unsupported ref {ref}')
for sel in ref.split('#')[-1].split('/')[1:]:
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
target = target[sel]
self._refs[ref] = target
else:
for v in n.values():
visit(v)
return n
return visit(schema)
def _desc_comment(self, schema: dict):
desc = schema.get("description", "").replace("\n", "\n// ") if 'description' in schema else None
return f'// {desc}\n' if desc else ''
@ -78,7 +128,7 @@ class SchemaToTypeScriptConverter:
else:
add_component(t, is_required=True)
return self._build_object_rule(properties, required, additional_properties=[])
return self._build_object_rule(properties, required, additional_properties={})
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
items = schema.get('items') or schema['prefixItems']
@ -94,4 +144,4 @@ class SchemaToTypeScriptConverter:
return 'any'
else:
return 'number' if schema_type == 'integer' else schema_type
return 'number' if schema_type == 'integer' else schema_type or 'any'