agent: fix response_format
This commit is contained in:
parent
6e52a9ce48
commit
701a66d80f
6 changed files with 81 additions and 17 deletions
|
@ -138,19 +138,28 @@ If you'd like to debug each binary separately (rather than have an agent spawning
|
||||||
```bash
|
```bash
|
||||||
# C++ server
|
# C++ server
|
||||||
make -j server
|
make -j server
|
||||||
./server --model mixtral.gguf --port 8081
|
./server \
|
||||||
|
--model mixtral.gguf \
|
||||||
|
--metrics \
|
||||||
|
-ctk q4_0 \
|
||||||
|
-ctv f16 \
|
||||||
|
-c 32768 \
|
||||||
|
--port 8081
|
||||||
|
|
||||||
# OpenAI compatibility layer
|
# OpenAI compatibility layer
|
||||||
python -m examples.openai \
|
python -m examples.openai \
|
||||||
--port 8080
|
--port 8080 \
|
||||||
--endpoint http://localhost:8081 \
|
--endpoint http://localhost:8081 \
|
||||||
--template_hf_model_id_fallback mistralai/Mixtral-8x7B-Instruct-v0.1
|
--template-hf-model-id-fallback mistralai/Mixtral-8x7B-Instruct-v0.1
|
||||||
|
|
||||||
# Or have the OpenAI compatibility layer spawn the C++ server under the hood:
|
# Or have the OpenAI compatibility layer spawn the C++ server under the hood:
|
||||||
# python -m examples.openai --model mixtral.gguf
|
# python -m examples.openai --model mixtral.gguf
|
||||||
|
|
||||||
# Agent itself:
|
# Agent itself:
|
||||||
python -m examples.agent --endpoint http://localhost:8080 \
|
python -m examples.agent --endpoint http://localhost:8080 \
|
||||||
|
--tools examples/agent/tools/example_summaries.py \
|
||||||
|
--format PyramidalSummary \
|
||||||
|
--goal "Create a pyramidal summary of Mankind's recent advancements"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Use existing tools (WIP)
|
## Use existing tools (WIP)
|
||||||
|
|
|
@ -10,7 +10,7 @@ import json, requests
|
||||||
|
|
||||||
from examples.json_schema_to_grammar import SchemaConverter
|
from examples.json_schema_to_grammar import SchemaConverter
|
||||||
from examples.agent.tools.std_tools import StandardTools
|
from examples.agent.tools.std_tools import StandardTools
|
||||||
from examples.openai.api import ChatCompletionRequest, ChatCompletionResponse, Message, Tool, ToolFunction
|
from examples.openai.api import ChatCompletionRequest, ChatCompletionResponse, Message, ResponseFormat, Tool, ToolFunction
|
||||||
from examples.agent.utils import collect_functions, load_module
|
from examples.agent.utils import collect_functions, load_module
|
||||||
from examples.openai.prompting import ToolsPromptStyle
|
from examples.openai.prompting import ToolsPromptStyle
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ def completion_with_tool_usage(
|
||||||
else:
|
else:
|
||||||
type_adapter = TypeAdapter(response_model)
|
type_adapter = TypeAdapter(response_model)
|
||||||
schema = type_adapter.json_schema()
|
schema = type_adapter.json_schema()
|
||||||
response_format={"type": "json_object", "schema": schema }
|
response_format=ResponseFormat(type="json_object", schema=schema)
|
||||||
|
|
||||||
tool_map = {fn.__name__: fn for fn in tools}
|
tool_map = {fn.__name__: fn for fn in tools}
|
||||||
tools_schemas = [
|
tools_schemas = [
|
||||||
|
@ -77,14 +77,15 @@ def completion_with_tool_usage(
|
||||||
if auth:
|
if auth:
|
||||||
headers["Authorization"] = auth
|
headers["Authorization"] = auth
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
endpoint,
|
f'{endpoint}/v1/chat/completions',
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json=request.model_dump(),
|
json=request.model_dump(),
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise Exception(f"Request failed ({response.status_code}): {response.text}")
|
raise Exception(f"Request failed ({response.status_code}): {response.text}")
|
||||||
|
|
||||||
response = ChatCompletionResponse(**response.json())
|
response_json = response.json()
|
||||||
|
response = ChatCompletionResponse(**response_json)
|
||||||
if verbose:
|
if verbose:
|
||||||
sys.stderr.write(f'# RESPONSE: {response.model_dump_json(indent=2)}\n')
|
sys.stderr.write(f'# RESPONSE: {response.model_dump_json(indent=2)}\n')
|
||||||
if response.error:
|
if response.error:
|
||||||
|
@ -169,7 +170,7 @@ def main(
|
||||||
if not endpoint:
|
if not endpoint:
|
||||||
server_port = 8080
|
server_port = 8080
|
||||||
server_host = 'localhost'
|
server_host = 'localhost'
|
||||||
endpoint: str = f'http://{server_host}:{server_port}/v1/chat/completions'
|
endpoint = f'http://{server_host}:{server_port}'
|
||||||
if verbose:
|
if verbose:
|
||||||
sys.stderr.write(f"# Starting C++ server with model {model} on {endpoint}\n")
|
sys.stderr.write(f"# Starting C++ server with model {model} on {endpoint}\n")
|
||||||
cmd = [
|
cmd = [
|
||||||
|
|
|
@ -28,8 +28,8 @@ class Tool(BaseModel):
|
||||||
function: ToolFunction
|
function: ToolFunction
|
||||||
|
|
||||||
class ResponseFormat(BaseModel):
|
class ResponseFormat(BaseModel):
|
||||||
type: str
|
type: Literal["json_object"]
|
||||||
json_schema: Optional[Any] = None
|
schema: Optional[Dict] = None
|
||||||
|
|
||||||
class LlamaCppParams(BaseModel):
|
class LlamaCppParams(BaseModel):
|
||||||
n_predict: Optional[int] = None
|
n_predict: Optional[int] = None
|
||||||
|
|
|
@ -712,15 +712,19 @@ def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool, tool_style: Op
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}")
|
raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}")
|
||||||
|
|
||||||
_ts_converter = SchemaToTypeScriptConverter()
|
|
||||||
|
|
||||||
# os.environ.get('NO_TS')
|
# os.environ.get('NO_TS')
|
||||||
def _please_respond_with_schema(schema: dict) -> str:
|
def _please_respond_with_schema(schema: dict) -> str:
|
||||||
# sig = json.dumps(schema, indent=2)
|
# sig = json.dumps(schema, indent=2)
|
||||||
|
_ts_converter = SchemaToTypeScriptConverter()
|
||||||
|
_ts_converter.resolve_refs(schema, 'schema')
|
||||||
sig = _ts_converter.visit(schema)
|
sig = _ts_converter.visit(schema)
|
||||||
return f'Please respond in JSON format with the following schema: {sig}'
|
return f'Please respond in JSON format with the following schema: {sig}'
|
||||||
|
|
||||||
def _tools_typescript_signatures(tools: list[Tool]) -> str:
|
def _tools_typescript_signatures(tools: list[Tool]) -> str:
|
||||||
|
_ts_converter = SchemaToTypeScriptConverter()
|
||||||
|
for tool in tools:
|
||||||
|
_ts_converter.resolve_refs(tool.function.parameters, tool.function.name)
|
||||||
|
|
||||||
return 'namespace functions {\n' + '\n'.join(
|
return 'namespace functions {\n' + '\n'.join(
|
||||||
'// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
|
'// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
|
||||||
'type ' + tool.function.name + ' = (_: ' + _ts_converter.visit(tool.function.parameters) + ") => any;\n"
|
'type ' + tool.function.name + ' = (_: ' + _ts_converter.visit(tool.function.parameters) + ") => any;\n"
|
||||||
|
|
|
@ -73,7 +73,7 @@ def main(
|
||||||
]
|
]
|
||||||
server_process = subprocess.Popen(cmd, stdout=sys.stderr)
|
server_process = subprocess.Popen(cmd, stdout=sys.stderr)
|
||||||
atexit.register(server_process.kill)
|
atexit.register(server_process.kill)
|
||||||
endpoint = f"http://{server_host}:{server_port}/completions"
|
endpoint = f"http://{server_host}:{server_port}"
|
||||||
|
|
||||||
|
|
||||||
# print(chat_template.render([
|
# print(chat_template.render([
|
||||||
|
@ -125,7 +125,7 @@ def main(
|
||||||
|
|
||||||
if chat_request.response_format is not None:
|
if chat_request.response_format is not None:
|
||||||
assert chat_request.response_format.type == "json_object", f"Unsupported response format: {chat_request.response_format.type}"
|
assert chat_request.response_format.type == "json_object", f"Unsupported response format: {chat_request.response_format.type}"
|
||||||
response_schema = chat_request.response_format.json_schema or {}
|
response_schema = chat_request.response_format.schema or {}
|
||||||
else:
|
else:
|
||||||
response_schema = None
|
response_schema = None
|
||||||
|
|
||||||
|
@ -164,7 +164,7 @@ def main(
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{endpoint}",
|
f'{endpoint}/completions',
|
||||||
json=data,
|
json=data,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=None)
|
timeout=None)
|
||||||
|
|
|
@ -14,6 +14,56 @@ class SchemaToTypeScriptConverter:
|
||||||
# // where to get weather.
|
# // where to get weather.
|
||||||
# location: string,
|
# location: string,
|
||||||
# }) => any;
|
# }) => any;
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._refs = {}
|
||||||
|
self._refs_being_resolved = set()
|
||||||
|
|
||||||
|
def resolve_refs(self, schema: dict, url: str):
|
||||||
|
'''
|
||||||
|
Resolves all $ref fields in the given schema, fetching any remote schemas,
|
||||||
|
replacing $ref with absolute reference URL and populating self._refs with the
|
||||||
|
respective referenced (sub)schema dictionaries.
|
||||||
|
'''
|
||||||
|
def visit(n: dict):
|
||||||
|
if isinstance(n, list):
|
||||||
|
return [visit(x) for x in n]
|
||||||
|
elif isinstance(n, dict):
|
||||||
|
ref = n.get('$ref')
|
||||||
|
if ref is not None and ref not in self._refs:
|
||||||
|
if ref.startswith('https://'):
|
||||||
|
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
|
||||||
|
import requests
|
||||||
|
|
||||||
|
frag_split = ref.split('#')
|
||||||
|
base_url = frag_split[0]
|
||||||
|
|
||||||
|
target = self._refs.get(base_url)
|
||||||
|
if target is None:
|
||||||
|
target = self.resolve_refs(requests.get(ref).json(), base_url)
|
||||||
|
self._refs[base_url] = target
|
||||||
|
|
||||||
|
if len(frag_split) == 1 or frag_split[-1] == '':
|
||||||
|
return target
|
||||||
|
elif ref.startswith('#/'):
|
||||||
|
target = schema
|
||||||
|
ref = f'{url}{ref}'
|
||||||
|
n['$ref'] = ref
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unsupported ref {ref}')
|
||||||
|
|
||||||
|
for sel in ref.split('#')[-1].split('/')[1:]:
|
||||||
|
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||||
|
target = target[sel]
|
||||||
|
|
||||||
|
self._refs[ref] = target
|
||||||
|
else:
|
||||||
|
for v in n.values():
|
||||||
|
visit(v)
|
||||||
|
|
||||||
|
return n
|
||||||
|
return visit(schema)
|
||||||
|
|
||||||
def _desc_comment(self, schema: dict):
|
def _desc_comment(self, schema: dict):
|
||||||
desc = schema.get("description", "").replace("\n", "\n// ") if 'description' in schema else None
|
desc = schema.get("description", "").replace("\n", "\n// ") if 'description' in schema else None
|
||||||
return f'// {desc}\n' if desc else ''
|
return f'// {desc}\n' if desc else ''
|
||||||
|
@ -78,7 +128,7 @@ class SchemaToTypeScriptConverter:
|
||||||
else:
|
else:
|
||||||
add_component(t, is_required=True)
|
add_component(t, is_required=True)
|
||||||
|
|
||||||
return self._build_object_rule(properties, required, additional_properties=[])
|
return self._build_object_rule(properties, required, additional_properties={})
|
||||||
|
|
||||||
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
|
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
|
||||||
items = schema.get('items') or schema['prefixItems']
|
items = schema.get('items') or schema['prefixItems']
|
||||||
|
@ -94,4 +144,4 @@ class SchemaToTypeScriptConverter:
|
||||||
return 'any'
|
return 'any'
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return 'number' if schema_type == 'integer' else schema_type
|
return 'number' if schema_type == 'integer' else schema_type or 'any'
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue