agent: fix response_format
This commit is contained in:
parent
6e52a9ce48
commit
701a66d80f
6 changed files with 81 additions and 17 deletions
|
@ -138,19 +138,28 @@ If you'd like to debug each binary separately (rather than have an agent spawing
|
|||
```bash
|
||||
# C++ server
|
||||
make -j server
|
||||
./server --model mixtral.gguf --port 8081
|
||||
./server \
|
||||
--model mixtral.gguf \
|
||||
--metrics \
|
||||
-ctk q4_0 \
|
||||
-ctv f16 \
|
||||
-c 32768 \
|
||||
--port 8081
|
||||
|
||||
# OpenAI compatibility layer
|
||||
python -m examples.openai \
|
||||
--port 8080
|
||||
--port 8080 \
|
||||
--endpoint http://localhost:8081 \
|
||||
--template_hf_model_id_fallback mistralai/Mixtral-8x7B-Instruct-v0.1
|
||||
--template-hf-model-id-fallback mistralai/Mixtral-8x7B-Instruct-v0.1
|
||||
|
||||
# Or have the OpenAI compatibility layer spawn the C++ server under the hood:
|
||||
# python -m examples.openai --model mixtral.gguf
|
||||
|
||||
# Agent itself:
|
||||
python -m examples.agent --endpoint http://localhost:8080 \
|
||||
--tools examples/agent/tools/example_summaries.py \
|
||||
--format PyramidalSummary \
|
||||
--goal "Create a pyramidal summary of Mankind's recent advancements"
|
||||
```
|
||||
|
||||
## Use existing tools (WIP)
|
||||
|
|
|
@ -10,7 +10,7 @@ import json, requests
|
|||
|
||||
from examples.json_schema_to_grammar import SchemaConverter
|
||||
from examples.agent.tools.std_tools import StandardTools
|
||||
from examples.openai.api import ChatCompletionRequest, ChatCompletionResponse, Message, Tool, ToolFunction
|
||||
from examples.openai.api import ChatCompletionRequest, ChatCompletionResponse, Message, ResponseFormat, Tool, ToolFunction
|
||||
from examples.agent.utils import collect_functions, load_module
|
||||
from examples.openai.prompting import ToolsPromptStyle
|
||||
|
||||
|
@ -46,7 +46,7 @@ def completion_with_tool_usage(
|
|||
else:
|
||||
type_adapter = TypeAdapter(response_model)
|
||||
schema = type_adapter.json_schema()
|
||||
response_format={"type": "json_object", "schema": schema }
|
||||
response_format=ResponseFormat(type="json_object", schema=schema)
|
||||
|
||||
tool_map = {fn.__name__: fn for fn in tools}
|
||||
tools_schemas = [
|
||||
|
@ -77,14 +77,15 @@ def completion_with_tool_usage(
|
|||
if auth:
|
||||
headers["Authorization"] = auth
|
||||
response = requests.post(
|
||||
endpoint,
|
||||
f'{endpoint}/v1/chat/completions',
|
||||
headers=headers,
|
||||
json=request.model_dump(),
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Request failed ({response.status_code}): {response.text}")
|
||||
|
||||
response = ChatCompletionResponse(**response.json())
|
||||
response_json = response.json()
|
||||
response = ChatCompletionResponse(**response_json)
|
||||
if verbose:
|
||||
sys.stderr.write(f'# RESPONSE: {response.model_dump_json(indent=2)}\n')
|
||||
if response.error:
|
||||
|
@ -169,7 +170,7 @@ def main(
|
|||
if not endpoint:
|
||||
server_port = 8080
|
||||
server_host = 'localhost'
|
||||
endpoint: str = f'http://{server_host}:{server_port}/v1/chat/completions'
|
||||
endpoint = f'http://{server_host}:{server_port}'
|
||||
if verbose:
|
||||
sys.stderr.write(f"# Starting C++ server with model {model} on {endpoint}\n")
|
||||
cmd = [
|
||||
|
|
|
@ -28,8 +28,8 @@ class Tool(BaseModel):
|
|||
function: ToolFunction
|
||||
|
||||
class ResponseFormat(BaseModel):
|
||||
type: str
|
||||
json_schema: Optional[Any] = None
|
||||
type: Literal["json_object"]
|
||||
schema: Optional[Dict] = None
|
||||
|
||||
class LlamaCppParams(BaseModel):
|
||||
n_predict: Optional[int] = None
|
||||
|
|
|
@ -712,15 +712,19 @@ def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool, tool_style: Op
|
|||
else:
|
||||
raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}")
|
||||
|
||||
_ts_converter = SchemaToTypeScriptConverter()
|
||||
|
||||
# os.environ.get('NO_TS')
|
||||
def _please_respond_with_schema(schema: dict) -> str:
|
||||
# sig = json.dumps(schema, indent=2)
|
||||
_ts_converter = SchemaToTypeScriptConverter()
|
||||
_ts_converter.resolve_refs(schema, 'schema')
|
||||
sig = _ts_converter.visit(schema)
|
||||
return f'Please respond in JSON format with the following schema: {sig}'
|
||||
|
||||
def _tools_typescript_signatures(tools: list[Tool]) -> str:
|
||||
_ts_converter = SchemaToTypeScriptConverter()
|
||||
for tool in tools:
|
||||
_ts_converter.resolve_refs(tool.function.parameters, tool.function.name)
|
||||
|
||||
return 'namespace functions {\n' + '\n'.join(
|
||||
'// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
|
||||
'type ' + tool.function.name + ' = (_: ' + _ts_converter.visit(tool.function.parameters) + ") => any;\n"
|
||||
|
|
|
@ -73,7 +73,7 @@ def main(
|
|||
]
|
||||
server_process = subprocess.Popen(cmd, stdout=sys.stderr)
|
||||
atexit.register(server_process.kill)
|
||||
endpoint = f"http://{server_host}:{server_port}/completions"
|
||||
endpoint = f"http://{server_host}:{server_port}"
|
||||
|
||||
|
||||
# print(chat_template.render([
|
||||
|
@ -125,7 +125,7 @@ def main(
|
|||
|
||||
if chat_request.response_format is not None:
|
||||
assert chat_request.response_format.type == "json_object", f"Unsupported response format: {chat_request.response_format.type}"
|
||||
response_schema = chat_request.response_format.json_schema or {}
|
||||
response_schema = chat_request.response_format.schema or {}
|
||||
else:
|
||||
response_schema = None
|
||||
|
||||
|
@ -164,7 +164,7 @@ def main(
|
|||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{endpoint}",
|
||||
f'{endpoint}/completions',
|
||||
json=data,
|
||||
headers=headers,
|
||||
timeout=None)
|
||||
|
|
|
@ -14,6 +14,56 @@ class SchemaToTypeScriptConverter:
|
|||
# // where to get weather.
|
||||
# location: string,
|
||||
# }) => any;
|
||||
|
||||
def __init__(self):
|
||||
self._refs = {}
|
||||
self._refs_being_resolved = set()
|
||||
|
||||
def resolve_refs(self, schema: dict, url: str):
|
||||
'''
|
||||
Resolves all $ref fields in the given schema, fetching any remote schemas,
|
||||
replacing $ref with absolute reference URL and populating self._refs with the
|
||||
respective referenced (sub)schema dictionaries.
|
||||
'''
|
||||
def visit(n: dict):
|
||||
if isinstance(n, list):
|
||||
return [visit(x) for x in n]
|
||||
elif isinstance(n, dict):
|
||||
ref = n.get('$ref')
|
||||
if ref is not None and ref not in self._refs:
|
||||
if ref.startswith('https://'):
|
||||
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
|
||||
import requests
|
||||
|
||||
frag_split = ref.split('#')
|
||||
base_url = frag_split[0]
|
||||
|
||||
target = self._refs.get(base_url)
|
||||
if target is None:
|
||||
target = self.resolve_refs(requests.get(ref).json(), base_url)
|
||||
self._refs[base_url] = target
|
||||
|
||||
if len(frag_split) == 1 or frag_split[-1] == '':
|
||||
return target
|
||||
elif ref.startswith('#/'):
|
||||
target = schema
|
||||
ref = f'{url}{ref}'
|
||||
n['$ref'] = ref
|
||||
else:
|
||||
raise ValueError(f'Unsupported ref {ref}')
|
||||
|
||||
for sel in ref.split('#')[-1].split('/')[1:]:
|
||||
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
target = target[sel]
|
||||
|
||||
self._refs[ref] = target
|
||||
else:
|
||||
for v in n.values():
|
||||
visit(v)
|
||||
|
||||
return n
|
||||
return visit(schema)
|
||||
|
||||
def _desc_comment(self, schema: dict):
|
||||
desc = schema.get("description", "").replace("\n", "\n// ") if 'description' in schema else None
|
||||
return f'// {desc}\n' if desc else ''
|
||||
|
@ -78,7 +128,7 @@ class SchemaToTypeScriptConverter:
|
|||
else:
|
||||
add_component(t, is_required=True)
|
||||
|
||||
return self._build_object_rule(properties, required, additional_properties=[])
|
||||
return self._build_object_rule(properties, required, additional_properties={})
|
||||
|
||||
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
|
||||
items = schema.get('items') or schema['prefixItems']
|
||||
|
@ -94,4 +144,4 @@ class SchemaToTypeScriptConverter:
|
|||
return 'any'
|
||||
|
||||
else:
|
||||
return 'number' if schema_type == 'integer' else schema_type
|
||||
return 'number' if schema_type == 'integer' else schema_type or 'any'
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue