This commit is contained in:
Olivier Chafik 2024-10-26 22:17:55 +01:00 committed by GitHub
commit bc2b15012c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 569 additions and 370 deletions

View file

@ -9,6 +9,11 @@
#include <unordered_set>
#include <vector>
#if defined(LLAMA_USE_CURL)
#include <curl/curl.h>
#include <curl/easy.h>
#endif
using json = nlohmann::ordered_json;
template <typename Iterator>
@ -387,15 +392,21 @@ static std::string format_literal(const std::string & literal) {
return "\"" + escaped + "\"";
}
// libcurl write callback (installed via CURLOPT_WRITEFUNCTION): appends the
// received chunk to the std::ostringstream handed over through CURLOPT_WRITEDATA.
//
//   ptr   - start of the received bytes (not NUL-terminated)
//   size  - size of one element
//   nmemb - number of elements in the chunk
//   data  - must point to a std::ostringstream
//
// Returns size * nmemb, i.e. reports the whole chunk as handled.
static size_t json_schema_ref_curl_write_callback(char *ptr, size_t size, size_t nmemb, void *data) {
    const size_t n_bytes = size * nmemb;
    auto & response = *static_cast<std::ostringstream *>(data);
    // `ptr` is already a char *, so no cast is needed; the length is converted explicitly.
    response.write(ptr, static_cast<std::streamsize>(n_bytes));
    return n_bytes;
}
class SchemaConverter {
private:
std::function<json(const std::string &)> _fetch_json;
bool _dotall;
std::map<std::string, std::string> _rules;
std::unordered_map<std::string, json> _refs;
std::unordered_set<std::string> _refs_being_resolved;
std::vector<std::string> _errors;
std::vector<std::string> _warnings;
std::unordered_map<std::string, json> _external_refs;
std::vector<json> _ref_context;
std::string _add_rule(const std::string & name, const std::string & rule) {
std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-");
@ -683,17 +694,6 @@ private:
return out.str();
}
std::string _resolve_ref(const std::string & ref) {
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
_refs_being_resolved.insert(ref);
json resolved = _refs[ref];
ref_name = visit(resolved, ref_name);
_refs_being_resolved.erase(ref);
}
return ref_name;
}
std::string _build_object_rule(
const std::vector<std::pair<std::string, json>> & properties,
const std::unordered_set<std::string> & required,
@ -807,86 +807,131 @@ private:
}
public:
SchemaConverter(
const std::function<json(const std::string &)> & fetch_json,
bool dotall)
: _fetch_json(fetch_json), _dotall(dotall)
SchemaConverter(bool dotall) : _dotall(dotall)
{
_rules["space"] = SPACE_RULE;
}
void resolve_refs(json & schema, const std::string & url) {
/*
* Resolves all $ref fields in the given schema, fetching any remote schemas,
* replacing each $ref with absolute reference URL and populates _refs with the
* respective referenced (sub)schema dictionaries.
*/
std::function<void(json &)> visit_refs = [&](json & n) {
if (n.is_array()) {
for (auto & x : n) {
visit_refs(x);
}
} else if (n.is_object()) {
if (n.contains("$ref")) {
std::string ref = n["$ref"];
if (_refs.find(ref) == _refs.end()) {
json target;
if (ref.find("https://") == 0) {
std::string base_url = ref.substr(0, ref.find('#'));
auto it = _refs.find(base_url);
if (it != _refs.end()) {
target = it->second;
} else {
// Fetch the referenced schema and resolve its refs
auto referenced = _fetch_json(ref);
resolve_refs(referenced, base_url);
_refs[base_url] = referenced;
}
if (ref.find('#') == std::string::npos || ref.substr(ref.find('#') + 1).empty()) {
return;
}
} else if (ref.find("#/") == 0) {
target = schema;
n["$ref"] = url + ref;
ref = url + ref;
} else {
_errors.push_back("Unsupported ref: " + ref);
return;
}
std::string pointer = ref.substr(ref.find('#') + 1);
std::vector<std::string> tokens = split(pointer, "/");
for (size_t i = 1; i < tokens.size(); ++i) {
std::string sel = tokens[i];
if (target.is_null() || !target.contains(sel)) {
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
return;
}
target = target[sel];
}
_refs[ref] = target;
}
} else {
for (auto & kv : n.items()) {
visit_refs(kv.value());
}
}
#if defined(LLAMA_USE_CURL)
_fetch_json = [&](const std::string & url) {
// TODO: implement HTTP caching semantics.
static std::unordered_map<std::string, json> cache;
auto it = cache.find(url);
if (it != cache.end()) {
return it->second;
}
std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
if (!curl) {
fprintf(stderr, "%s: error initializing libcurl\n", __func__);
return json::object();
}
};
visit_refs(schema);
std::ostringstream response;
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
#if defined(_WIN32)
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
// operating system. Currently implemented under MS-Windows.
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, &json_schema_ref_curl_write_callback);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response);
CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
throw std::runtime_error("Failed to fetch " + url + ": " + curl_easy_strerror(res));
}
response << '\0';
return cache[url] = json::parse(response.str());
};
#else
_fetch_json = [&](const std::string &) {
_errors.push_back("Fetching external refs not supported, please recompile with CURL support.");
return json::object();
};
#endif
}
// Builds the grammar literal for a constant: the value is serialized with
// json::dump() and the resulting text is passed through format_literal().
std::string _generate_constant_rule(const json & value) {
    const std::string serialized = value.dump();
    return format_literal(serialized);
}
// Result of resolving a "$ref" string (see _resolve_ref).
struct ResolvedRef {
    // Schema node the ref points to; a null json when resolution failed.
    json target;
    // Last '/'-separated token of the ref's fragment, or "" when there is no fragment.
    std::string name;
    // True when the ref had no URL part and was resolved against the current ref context.
    bool is_local;
};
// Resolves a "$ref" string of the form "[url][#/path/to/node]".
//
// - An empty url part means a local ref: the lookup starts at _ref_context.back().
// - Otherwise the document is taken from _external_refs, or fetched through
//   _fetch_json (https:// urls only) and then cached in _external_refs.
// - Without a '#' the whole document is returned with an empty name.
// - With a fragment, each '/'-separated token after the first is used as an
//   object key to descend into the document.
//
// On any failure a message is appended to _errors and {null json, "", false} is returned.
ResolvedRef _resolve_ref(const std::string & ref) {
    // Every failure path records a message and yields the same "null" result.
    auto fail = [&](const std::string & message) -> ResolvedRef {
        _errors.push_back(message);
        return {json(), "", false};
    };

    const auto pieces = split(ref, "#");
    if (pieces.size() > 2) {
        return fail("Unsupported ref: " + ref);
    }

    const std::string & location = pieces[0];
    const bool local = location.empty();

    json node;
    if (local) {
        if (_ref_context.empty()) {
            return fail("Error resolving ref " + ref + ": no context");
        }
        node = _ref_context.back();
    } else {
        const auto cached = _external_refs.find(location);
        if (cached != _external_refs.end()) {
            node = cached->second;
        } else {
            if (location.rfind("https://", 0) != 0) {
                return fail("Error resolving ref " + ref + ": unsupported url scheme");
            }
            // Fetch the referenced schema and remember it for later refs to the same url.
            node = _fetch_json(location);
            _external_refs[location] = node;
        }
    }

    if (pieces.size() == 1) {
        return {node, "", local};
    }

    // The token before the first '/' is skipped, hence the loop starts at 1.
    const auto path = split(pieces[1], "/");
    for (size_t idx = 1; idx < path.size(); ++idx) {
        const std::string & key = path[idx];
        if (node.is_null() || !node.contains(key)) {
            return fail("Error resolving ref " + ref + ": " + key + " not in " + node.dump());
        }
        node = node[key];
    }

    const std::string last = path.empty() ? "" : path.back();
    return {node, last, local};
}
std::string visit(const json & schema, const std::string & name) {
json schema_type = schema.contains("type") ? schema["type"] : json();
std::string schema_format = schema.contains("format") ? schema["format"].get<std::string>() : "";
std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name;
if (schema.contains("$ref")) {
return _add_rule(rule_name, _resolve_ref(schema["$ref"]));
if (_ref_context.empty()) {
_ref_context.push_back(schema);
auto ret = visit(schema, name);
_ref_context.pop_back();
return ret;
}
if (schema.contains("$ref") && schema["$ref"].is_string()) {
const auto & ref = schema["$ref"].get<std::string>();
auto resolved = _resolve_ref(ref);
if (resolved.target.is_null()) {
return "";
}
if (!resolved.is_local) {
_ref_context.push_back(resolved.target);
}
auto ret = visit(resolved.target, (name.empty() || resolved.name.empty()) ? name : resolved.name);
if (!resolved.is_local) {
_ref_context.pop_back();
}
return ret;
} else if (schema.contains("oneOf") || schema.contains("anyOf")) {
std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>();
return _add_rule(rule_name, _generate_union_rule(name, alt_schemas));
@ -906,55 +951,6 @@ public:
enum_values.push_back(_generate_constant_rule(v));
}
return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space");
} else if ((schema_type.is_null() || schema_type == "object")
&& (schema.contains("properties") ||
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
std::unordered_set<std::string> required;
if (schema.contains("required") && schema["required"].is_array()) {
for (const auto & item : schema["required"]) {
if (item.is_string()) {
required.insert(item.get<std::string>());
}
}
}
std::vector<std::pair<std::string, json>> properties;
if (schema.contains("properties")) {
for (const auto & prop : schema["properties"].items()) {
properties.emplace_back(prop.key(), prop.value());
}
}
return _add_rule(rule_name,
_build_object_rule(
properties, required, name,
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
} else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
std::unordered_set<std::string> required;
std::vector<std::pair<std::string, json>> properties;
std::string hybrid_name = name;
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
if (comp_schema.contains("$ref")) {
add_component(_refs[comp_schema["$ref"]], is_required);
} else if (comp_schema.contains("properties")) {
for (const auto & prop : comp_schema["properties"].items()) {
properties.emplace_back(prop.key(), prop.value());
if (is_required) {
required.insert(prop.key());
}
}
} else {
// todo warning
}
};
for (auto & t : schema["allOf"]) {
if (t.contains("anyOf")) {
for (auto & tt : t["anyOf"]) {
add_component(tt, false);
}
} else {
add_component(t, true);
}
}
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
if (items.is_array()) {
@ -1005,8 +1001,71 @@ public:
_build_min_max_int(min_value, max_value, out);
out << ") space";
return _add_rule(rule_name, out.str());
} else if (schema.empty() || schema_type == "object") {
return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
} else if ((schema_type.is_null() || schema_type == "object")) {
std::unordered_set<std::string> required;
std::vector<std::pair<std::string, json>> properties;
auto is_explicit_object = schema_type == "object";
json additional_properties;
if (schema.contains("additionalProperties")) {
is_explicit_object = true;
additional_properties = schema["additionalProperties"];
}
if (schema.contains("properties") && schema["properties"].is_object()) {
is_explicit_object = true;
for (const auto & prop : schema["properties"].items()) {
if (prop.value().is_object()) {
properties.emplace_back(prop.key(), prop.value());
}
}
}
if (schema.contains("required") && schema["required"].is_array()) {
for (const auto & item : schema["required"]) {
if (item.is_string()) {
required.insert(item.get<std::string>());
}
}
}
if (schema.contains("allOf") && schema["allOf"].is_array()) {
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
if (comp_schema.contains("$ref") && comp_schema["$ref"].is_string()) {
auto resolved = _resolve_ref(comp_schema["$ref"].get<std::string>());
add_component(resolved.target, is_required);
} else if (comp_schema.contains("properties")) {
for (const auto & prop : comp_schema["properties"].items()) {
properties.emplace_back(prop.key(), prop.value());
if (is_required) {
required.insert(prop.key());
}
}
if (comp_schema.contains("additionalProperties")) {
if (additional_properties.is_null()) {
additional_properties = comp_schema["additionalProperties"];
} else if (additional_properties != comp_schema["additionalProperties"]) {
_warnings.push_back("Inconsistent additionalProperties in allOf");
}
}
} else {
_warnings.push_back("Unsupported allOf schema");
}
};
for (auto & t : schema["allOf"]) {
if (t.contains("anyOf")) {
for (auto & tt : t["anyOf"]) {
add_component(tt, false);
}
} else {
add_component(t, true);
}
}
}
if (properties.empty() && (additional_properties == true || additional_properties.is_null())) {
return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
}
auto default_additional_properties = is_explicit_object ? json() : json(false);
return _add_rule(rule_name,
_build_object_rule(
properties, required, name,
additional_properties.is_null() ? default_additional_properties : additional_properties));
} else {
if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
_errors.push_back("Unrecognized schema: " + schema.dump());
@ -1036,9 +1095,8 @@ public:
};
std::string json_schema_to_grammar(const json & schema) {
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
SchemaConverter converter(/* dotall= */ false);
auto copy = schema;
converter.resolve_refs(copy, "input");
converter.visit(copy, "");
converter.check_errors();
return converter.format_grammar();

View file

@ -0,0 +1,27 @@
/*
JSON Schema to Grammar converter (JavaScript version)
There are C++ and Python converters w/ the same features.
(More flags are currently exposed by the Python version)
Usage:
node examples/json_schema_to_grammar.mjs schema.json
node examples/json_schema_to_grammar.mjs https://json.schemastore.org/tsconfig.json
echo '{"type": "object"}' | node examples/json_schema_to_grammar.mjs -
*/
import { readFileSync } from "fs"
import { SchemaConverter } from "./server/public/json-schema-to-grammar.mjs"
import fs from 'fs'
// Command-line entry point: load a JSON schema from a file path, an https:// URL,
// or stdin ('-'), convert it with SchemaConverter and print the grammar.
const [, , file] = process.argv
if (!file) {
  // Without this guard a missing argument crashes with an opaque TypeError on `file.startsWith`.
  console.error('Usage: node examples/json_schema_to_grammar.mjs <schema.json | https://... | ->')
  process.exit(1)
}
let schema;
if (file === '-') {
  // File descriptor 0 is stdin.
  schema = JSON.parse(fs.readFileSync(0, 'utf8'))
} else if (file.startsWith('https://')) {
  const response = await fetch(file)
  if (!response.ok) {
    // Do not try to parse an HTTP error page as a schema.
    throw new Error(`Failed to fetch ${file}: HTTP ${response.status}`)
  }
  schema = await response.json()
} else {
  schema = JSON.parse(readFileSync(file, "utf8"));
}
const converter = new SchemaConverter({})
converter.visit(schema, '')
console.log(converter.formatGrammar())

View file

@ -1,4 +1,16 @@
#!/usr/bin/env python3
'''
JSON Schema to Grammar conversion
There are C++ and JavaScript converters w/ the same features.
Usage:
python examples/json_schema_to_grammar.py schema.json
python examples/json_schema_to_grammar.py https://json.schemastore.org/tsconfig.json
echo '{"type": "object"}' | python examples/json_schema_to_grammar.py -
Also see https://github.com/ggerganov/llama.cpp/tree/master/grammars
'''
from __future__ import annotations
import argparse
@ -237,16 +249,15 @@ ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')
class SchemaConverter:
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
def __init__(self, *, prop_order, dotall, raw_pattern):
self._prop_order = prop_order
self._allow_fetch = allow_fetch
self._dotall = dotall
self._raw_pattern = raw_pattern
self._rules = {
'space': SPACE_RULE,
}
self._refs = {}
self._refs_being_resolved = set()
self._external_refs = {}
self._ref_context = []
def _format_literal(self, literal):
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
@ -334,51 +345,6 @@ class SchemaConverter:
self._rules[key] = rule
return key
def resolve_refs(self, schema: dict, url: str):
    '''
    Walk `schema` and resolve every `$ref` it contains.

    Local refs (`#/...`) are rewritten in place to the absolute form `{url}#/...`;
    `https://` refs cause the remote schema to be fetched (only when fetching is
    allowed) and resolved recursively. Each resolved ref is stored in `self._refs`,
    keyed by its absolute reference string.

    Returns the result of visiting `schema` (the same object for a dict input).
    Raises ValueError for any other ref form.
    '''
    def walk(node):
        if isinstance(node, list):
            return [walk(item) for item in node]
        if not isinstance(node, dict):
            return node

        ref = node.get('$ref')
        if ref is None or ref in self._refs:
            # Nothing to resolve on this node: descend into its values instead.
            for child in node.values():
                walk(child)
            return node

        if ref.startswith('https://'):
            assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
            import requests

            frag_split = ref.split('#')
            base_url = frag_split[0]
            target = self._refs.get(base_url)
            if target is None:
                target = self.resolve_refs(requests.get(ref).json(), base_url)
                self._refs[base_url] = target
            # A ref without a fragment designates the whole remote document.
            if len(frag_split) == 1 or frag_split[-1] == '':
                return target
        elif ref.startswith('#/'):
            target = schema
            ref = f'{url}{ref}'
            node['$ref'] = ref
        else:
            raise ValueError(f'Unsupported ref {ref}')

        # Follow the fragment's path segments (the part before the first '/' is skipped).
        pointer = ref.split('#')[-1]
        for sel in pointer.split('/')[1:]:
            assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
            target = target[sel]
        self._refs[ref] = target
        return node

    return walk(schema)
def _generate_union_rule(self, name, alt_schemas):
return ' | '.join((
self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
@ -543,25 +509,64 @@ class SchemaConverter:
else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")
def _resolve_ref(self, ref):
ref_name = ref.split('/')[-1]
if ref_name not in self._rules and ref not in self._refs_being_resolved:
self._refs_being_resolved.add(ref)
resolved = self._refs[ref]
ref_name = self.visit(resolved, ref_name)
self._refs_being_resolved.remove(ref)
return ref_name
def _generate_constant_rule(self, value):
return self._format_literal(json.dumps(value))
class ResolvedRef:
    '''Outcome of resolving a `$ref`: the target schema, a name derived from the ref, and whether the ref was local.'''

    def __init__(self, target: Any, name: str, is_local: bool):
        self.target, self.name, self.is_local = target, name, is_local
def _resolve_ref(self, ref: str):
parts = ref.split('#')
assert len(parts) <= 2, f'Unsupported ref: {ref}'
url = parts[0]
target = None
is_local = not url
if is_local:
assert self._ref_context, f'Error resolving ref {ref}: no context'
target = self._ref_context[-1]
else:
target = self._external_refs.get(url)
if target is None:
# Fetch the referenced schema and resolve its refs
assert url.startswith("https://"), f"Error resolving ref {ref}: unsupported url scheme"
import requests
target = requests.get(url).json()
self._external_refs[url] = target
if len(parts) == 1:
return self.ResolvedRef(target, '', is_local)
else:
tokens = parts[1].split('/')
for sel in tokens[1:]:
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
target = target[sel]
return self.ResolvedRef(target, tokens[-1] if tokens else '', is_local)
def visit(self, schema, name):
schema_type = schema.get('type')
schema_format = schema.get('format')
rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'
if not self._ref_context:
self._ref_context.append(schema)
try:
return self.visit(schema, name)
finally:
self._ref_context.pop()
if (ref := schema.get('$ref')) is not None:
return self._add_rule(rule_name, self._resolve_ref(ref))
resolved = self._resolve_ref(ref)
if not resolved.is_local:
self._ref_context.append(resolved.target)
try:
return self.visit(resolved.target, name if name == '' or resolved.name == '' else resolved.name)
finally:
if not resolved.is_local:
self._ref_context.pop()
elif 'oneOf' in schema or 'anyOf' in schema:
return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))
@ -576,36 +581,6 @@ class SchemaConverter:
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
return self._add_rule(rule_name, rule)
elif schema_type in (None, 'object') and \
('properties' in schema or \
('additionalProperties' in schema and schema['additionalProperties'] is not True)):
required = set(schema.get('required', []))
properties = list(schema.get('properties', {}).items())
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
elif schema_type in (None, 'object') and 'allOf' in schema:
required = set()
properties = []
hybrid_name = name
def add_component(comp_schema, is_required):
if (ref := comp_schema.get('$ref')) is not None:
comp_schema = self._refs[ref]
if 'properties' in comp_schema:
for prop_name, prop_schema in comp_schema['properties'].items():
properties.append((prop_name, prop_schema))
if is_required:
required.add(prop_name)
for t in schema['allOf']:
if 'anyOf' in t:
for tt in t['anyOf']:
add_component(tt, is_required=False)
else:
add_component(t, is_required=True)
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
items = schema.get('items') or schema['prefixItems']
if isinstance(items, list):
@ -660,8 +635,44 @@ class SchemaConverter:
out.append(") space")
return self._add_rule(rule_name, ''.join(out))
elif (schema_type == 'object') or (len(schema) == 0):
return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))
elif (schema_type == 'object') or (schema_type is None):
required = set(schema.get('required', []))
properties = list(schema.get('properties', {}).items())
is_explicit_object = schema_type == 'object' or 'properties' in schema or 'additionalProperties' in schema
additional_properties = schema.get('additionalProperties')
def add_component(comp_schema, is_required):
if (ref := comp_schema.get('$ref')) is not None:
resolved = self._resolve_ref(ref)
comp_schema = resolved.target
if 'properties' in comp_schema:
for prop_name, prop_schema in comp_schema['properties'].items():
properties.append((prop_name, prop_schema))
if is_required:
required.add(prop_name)
if 'additionalProperties' in comp_schema:
if additional_properties is None:
additional_properties = comp_schema['additionalProperties']
elif additional_properties != comp_schema['additionalProperties']:
raise ValueError('Inconsistent additionalProperties in allOf')
for t in schema.get('allOf', []):
if 'anyOf' in t:
for tt in t['anyOf']:
add_component(tt, is_required=False)
else:
add_component(t, is_required=True)
if not properties and (additional_properties == True or additional_properties is None):
return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))
default_additional_properties = None if is_explicit_object else False
return self._add_rule(
rule_name,
self._build_object_rule(
properties, required, name,
additional_properties if additional_properties is not None else default_additional_properties))
else:
assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
@ -767,11 +778,6 @@ def main(args_in = None):
given precedence over optional properties.
'''
)
parser.add_argument(
'--allow-fetch',
action='store_true',
default=False,
help='Whether to allow fetching referenced schemas over HTTPS')
parser.add_argument(
'--dotall',
action='store_true',
@ -799,10 +805,8 @@ def main(args_in = None):
schema = json.load(f)
converter = SchemaConverter(
prop_order={name: idx for idx, name in enumerate(args.prop_order)},
allow_fetch=args.allow_fetch,
dotall=args.dotall,
raw_pattern=args.raw_pattern)
schema = converter.resolve_refs(schema, url)
converter.visit(schema, '')
print(converter.format_grammar())

View file

@ -26,9 +26,8 @@ const propOrder = grammarJsonSchemaPropOrder
let grammar = null
if (grammarJsonSchemaFile) {
let schema = JSON.parse(readFileSync(grammarJsonSchemaFile, 'utf-8'))
const schema = JSON.parse(readFileSync(grammarJsonSchemaFile, 'utf-8'))
const converter = new SchemaConverter({prop_order: propOrder, allow_fetch: true})
schema = await converter.resolveRefs(schema, grammarJsonSchemaFile)
converter.visit(schema, '')
grammar = converter.formatGrammar()
}

View file

@ -564,14 +564,13 @@ const ConfigForm = (props) => {
const updateGrammarJsonSchemaPropOrder = (el) => grammarJsonSchemaPropOrder.value = el.target.value
const convertJSONSchemaGrammar = async () => {
try {
let schema = JSON.parse(params.value.grammar)
const schema = JSON.parse(params.value.grammar)
const converter = new SchemaConverter({
prop_order: grammarJsonSchemaPropOrder.value
.split(',')
.reduce((acc, cur, i) => ({ ...acc, [cur.trim()]: i }), {}),
allow_fetch: true,
})
schema = await converter.resolveRefs(schema, 'input')
converter.visit(schema, '')
params.value = {
...params.value,

View file

@ -861,14 +861,13 @@
const updateGrammarJsonSchemaPropOrder = (el) => grammarJsonSchemaPropOrder.value = el.target.value
const convertJSONSchemaGrammar = async () => {
try {
let schema = JSON.parse(params.value.grammar)
const schema = JSON.parse(params.value.grammar)
const converter = new SchemaConverter({
prop_order: grammarJsonSchemaPropOrder.value
.split(',')
.reduce((acc, cur, i) => ({ ...acc, [cur.trim()]: i }), {}),
allow_fetch: true,
})
schema = await converter.resolveRefs(schema, 'input')
converter.visit(schema, '')
params.value = {
...params.value,

View file

@ -264,11 +264,11 @@ const ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = new Set('^$.[]()|{}*+?');
export class SchemaConverter {
constructor(options) {
this._propOrder = options.prop_order || {};
this._allowFetch = options.allow_fetch || false;
this._dotall = options.dotall || false;
this._rules = {'space': SPACE_RULE};
this._refs = {};
this._refsBeingResolved = new Set();
this._externalRefs = new Map();
this._refContext = [];
}
_formatLiteral(literal) {
@ -306,60 +306,6 @@ export class SchemaConverter {
return key;
}
/**
 * Walks `schema` and resolves every `$ref` it finds, recording each referenced
 * (sub)schema in `this._refs` under its absolute reference string.
 *
 * - `https://` refs: the remote document is fetched (only when `_allowFetch` is
 *   set) and resolved recursively, then cached under its base URL.
 * - `#/` refs: rewritten in place to `${url}${ref}` and resolved against `schema`.
 * - Any other ref form throws.
 *
 * @param {object} schema Root schema that local refs are resolved against.
 * @param {string} url Prefix used to turn local refs into absolute ones.
 * @returns {Promise<*>} Result of visiting `schema`.
 * @throws {Error} On unsupported refs, disallowed fetches or missing path segments.
 */
async resolveRefs(schema, url) {
  const visit = async (n) => {
    if (Array.isArray(n)) {
      return Promise.all(n.map(visit));
    } else if (typeof n === 'object' && n !== null) {
      let ref = n.$ref;
      let target;
      // Only resolve refs that have not been recorded yet.
      if (ref !== undefined && !this._refs[ref]) {
        if (ref.startsWith('https://')) {
          if (!this._allowFetch) {
            throw new Error('Fetching remote schemas is not allowed (use --allow-fetch for force)');
          }
          const fetch = (await import('node-fetch')).default;
          const fragSplit = ref.split('#');
          const baseUrl = fragSplit[0];
          target = this._refs[baseUrl];
          if (!target) {
            target = await this.resolveRefs(await fetch(ref).then(res => res.json()), baseUrl);
            this._refs[baseUrl] = target;
          }
          // A ref without a fragment designates the whole remote document.
          if (fragSplit.length === 1 || fragSplit[fragSplit.length - 1] === '') {
            return target;
          }
        } else if (ref.startsWith('#/')) {
          target = schema;
          ref = `${url}${ref}`;
          n.$ref = ref;
        } else {
          throw new Error(`Unsupported ref ${ref}`);
        }
        // Follow the fragment's path segments (the part before the first '/' is dropped).
        const selectors = ref.split('#')[1].split('/').slice(1);
        for (const sel of selectors) {
          if (!target || !(sel in target)) {
            throw new Error(`Error resolving ref ${ref}: ${sel} not in ${JSON.stringify(target)}`);
          }
          target = target[sel];
        }
        this._refs[ref] = target;
      } else {
        // No unresolved ref on this node: descend into its values.
        await Promise.all(Object.values(n).map(visit));
      }
    }
    return n;
  };
  return visit(schema);
}
_generateUnionRule(name, altSchemas) {
return altSchemas
.map((altSchema, i) => this.visit(altSchema, `${name ?? ''}${name ? '-' : 'alternative-'}${i}`))
@ -590,29 +536,72 @@ export class SchemaConverter {
return out.join('');
}
_resolveRef(ref) {
let refName = ref.split('/').pop();
if (!(refName in this._rules) && !this._refsBeingResolved.has(ref)) {
this._refsBeingResolved.add(ref);
const resolved = this._refs[ref];
refName = this.visit(resolved, refName);
this._refsBeingResolved.delete(ref);
}
return refName;
}
_generateConstantRule(value) {
return this._formatLiteral(JSON.stringify(value));
}
_resolveRef(ref) {
const parts = ref.split('#');
if (parts.length !== 2) {
throw new Error(`Unsupported ref: ${ref}`);
}
const url = parts[0];
let target = null;
let isLocal = !url;
if (isLocal) {
if (this._refContext.length === 0) {
throw new Error(`Error resolving ref ${ref}: no context`);
}
target = this._refContext[this._refContext.length - 1];
} else {
target = this._externalRefs.get(url);
if (target === undefined) {
// Fetch the referenced schema and resolve its refs
if (!url.startsWith('https://')) {
throw new Error(`Error resolving ref ${ref}: unsupported url scheme`);
}
target = this._fetchJson(url);
this._externalRefs.set(url, target);
}
}
const tokens = parts[1].split('/');
for (const sel of tokens.slice(1)) {
if (target === null || !(sel in target)) {
throw new Error(`Error resolving ref ${ref}: ${sel} not in ${JSON.stringify(target)}`);
}
target = target[sel];
}
const name = tokens[tokens.length - 1] || '';
return {target, name, isLocal};
}
visit(schema, name) {
const schemaType = schema.type;
const schemaFormat = schema.format;
const ruleName = name in RESERVED_NAMES ? name + '-' : name == '' ? 'root' : name;
if (this._refContext.length === 0) {
this._refContext.push(schema);
try {
return this.visit(schema, name);
} finally {
this._refContext.pop();
}
}
const ref = schema.$ref;
if (ref !== undefined) {
return this._addRule(ruleName, this._resolveRef(ref));
const resolved = this._resolveRef(ref);
if (!resolved.isLocal) {
this._refContext.push(resolved.target);
}
try {
return this.visit(resolved.target, name === '' || resolved.name === '' ? name : resolved.name);
} finally {
if (!resolved.isLocal) {
this._refContext.pop();
}
}
} else if (schema.oneOf || schema.anyOf) {
return this._addRule(ruleName, this._generateUnionRule(name, schema.oneOf || schema.anyOf));
} else if (Array.isArray(schemaType)) {
@ -622,42 +611,6 @@ export class SchemaConverter {
} else if ('enum' in schema) {
const rule = '(' + schema.enum.map(v => this._generateConstantRule(v)).join(' | ') + ') space';
return this._addRule(ruleName, rule);
} else if ((schemaType === undefined || schemaType === 'object') &&
('properties' in schema ||
('additionalProperties' in schema && schema.additionalProperties !== true))) {
const required = new Set(schema.required || []);
const properties = Object.entries(schema.properties ?? {});
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, schema.additionalProperties));
} else if ((schemaType === undefined || schemaType === 'object') && 'allOf' in schema) {
const required = new Set();
const properties = [];
const addComponent = (compSchema, isRequired) => {
const ref = compSchema.$ref;
if (ref !== undefined) {
compSchema = this._refs[ref];
}
if ('properties' in compSchema) {
for (const [propName, propSchema] of Object.entries(compSchema.properties)) {
properties.push([propName, propSchema]);
if (isRequired) {
required.add(propName);
}
}
}
};
for (const t of schema.allOf) {
if ('anyOf' in t) {
for (const tt of t.anyOf) {
addComponent(tt, false);
}
} else {
addComponent(t, true);
}
}
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, null));
} else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) {
const items = schema.items ?? schema.prefixItems;
if (Array.isArray(items)) {
@ -706,8 +659,58 @@ export class SchemaConverter {
_generateMinMaxInt(minValue, maxValue, out);
out.push(") space");
return this._addRule(ruleName, out.join(''));
} else if ((schemaType === 'object') || (Object.keys(schema).length === 0)) {
return this._addRule(ruleName, this._addPrimitive('object', PRIMITIVE_RULES['object']));
} else if (schemaType === undefined || schemaType === 'object') {
const required = new Set(schema.required || []);
const properties = Object.entries(schema.properties ?? {});
const isExplicitObject = schemaType === 'object' || 'properties' in schema || 'additionalProperties' in schema;
let additionalProperties = schema.additionalProperties;
// Merges one allOf/anyOf component into the enclosing `properties`, `required`
// and `additionalProperties`. A `$ref` component is resolved first.
const addComponent = (compSchema, isRequired) => {
  const ref = compSchema.$ref;
  if (ref !== undefined) {
    const resolved = this._resolveRef(ref);
    compSchema = resolved.target;
  }
  if ('properties' in compSchema) {
    for (const [propName, propSchema] of Object.entries(compSchema.properties)) {
      properties.push([propName, propSchema]);
      if (isRequired) {
        required.add(propName);
      }
    }
    if ('additionalProperties' in compSchema) {
      // `additionalProperties` starts out as `undefined` when the schema has none,
      // so test with `== null` (matches both null and undefined); a strict
      // `=== null` never matched and every component was reported as inconsistent.
      if (additionalProperties == null) {
        additionalProperties = compSchema.additionalProperties;
      } else if (JSON.stringify(additionalProperties) !== JSON.stringify(compSchema.additionalProperties)) {
        // Compare by value: two equal schema objects are distinct references.
        throw new Error('Inconsistent additionalProperties in allOf');
      }
    }
  }
};
if ('allOf' in schema) {
for (const t of schema.allOf) {
if ('anyOf' in t) {
for (const tt of t.anyOf) {
addComponent(tt, false);
}
} else {
addComponent(t, true);
}
}
}
if (properties.length === 0 && (additionalProperties === true || additionalProperties == null)) {
return this._addRule(ruleName, this._addPrimitive('object', PRIMITIVE_RULES['object']));
}
const defaultAdditionalProperties = isExplicitObject ? null : false;
return this._addRule(
ruleName,
this._buildObjectRule(properties, required, name, additionalProperties ?? defaultAdditionalProperties)
);
} else {
if (!(schemaType in PRIMITIVE_RULES)) {
throw new Error(`Unrecognized schema: ${JSON.stringify(schema)}`);

View file

@ -634,14 +634,13 @@
const updateGrammarJsonSchemaPropOrder = (el) => grammarJsonSchemaPropOrder.value = el.target.value
const convertJSONSchemaGrammar = async () => {
try {
let schema = JSON.parse(params.value.grammar)
const schema = JSON.parse(params.value.grammar)
const converter = new SchemaConverter({
prop_order: grammarJsonSchemaPropOrder.value
.split(',')
.reduce((acc, cur, i) => ({ ...acc, [cur.trim()]: i }), {}),
allow_fetch: true,
})
schema = await converter.resolveRefs(schema, 'input')
converter.visit(schema, '')
params.value = {
...params.value,

View file

@ -637,14 +637,13 @@
const updateGrammarJsonSchemaPropOrder = (el) => grammarJsonSchemaPropOrder.value = el.target.value
const convertJSONSchemaGrammar = async () => {
try {
let schema = JSON.parse(params.value.grammar)
const schema = JSON.parse(params.value.grammar)
const converter = new SchemaConverter({
prop_order: grammarJsonSchemaPropOrder.value
.split(',')
.reduce((acc, cur, i) => ({ ...acc, [cur.trim()]: i }), {}),
allow_fetch: true,
})
schema = await converter.resolveRefs(schema, 'input')
converter.visit(schema, '')
params.value = {
...params.value,

View file

@ -123,8 +123,8 @@ You can use GBNF grammars:
- For the `/chat/completions` endpoint, passed inside the `response_format` body field (e.g. `{"type": "json_object", "schema": {"items": {}}}` or `{ type: "json_schema", json_schema: {"schema": ...} }`)
- In [llama-cli](../examples/main), passed as the `--json` / `-j` flag
- To convert to a grammar ahead of time:
- in CLI, with [examples/json_schema_to_grammar.py](../examples/json_schema_to_grammar.py)
- in JavaScript with [json-schema-to-grammar.mjs](../examples/server/public/json-schema-to-grammar.mjs) (this is used by the [server](../examples/server)'s Web UI)
- in Python with [`python examples/json_schema_to_grammar.py schema.json`](../examples/json_schema_to_grammar.py)
- in JavaScript with [`node examples/json_schema_to_grammar.mjs schema.json`](../examples/json_schema_to_grammar.mjs) (uses same lib as the [server](../examples/server)'s Web UI)
Take a look at [tests](../tests/test-json-schema-to-grammar.cpp) to see which features are likely supported (you'll also find usage examples in https://github.com/ggerganov/llama.cpp/pull/5978, https://github.com/ggerganov/llama.cpp/pull/6659 & https://github.com/ggerganov/llama.cpp/pull/6555).
@ -185,12 +185,10 @@ Here is also a list of known limitations (contributions welcome):
- `additionalProperties` defaults to `false` (produces faster grammars + reduces hallucinations).
- `"additionalProperties": true` may produce keys that contain unescaped newlines.
- Unsupported features are skipped silently. It is currently advised to use the command-line Python converter (see above) to see any warnings, and to inspect the resulting grammar / test it w/ [llama-gbnf-validator](../examples/gbnf-validator/gbnf-validator.cpp).
- Can't mix `properties` w/ `anyOf` / `oneOf` in the same type (https://github.com/ggerganov/llama.cpp/issues/7703)
- Can't mix `properties` w/ `oneOf` in the same type (https://github.com/ggerganov/llama.cpp/issues/7703)
- [prefixItems](https://json-schema.org/draft/2020-12/json-schema-core#name-prefixitems) is broken (but [items](https://json-schema.org/draft/2020-12/json-schema-core#name-items) works)
- `minimum`, `exclusiveMinimum`, `maximum`, `exclusiveMaximum`: only supported for `"type": "integer"` for now, not `number`
- Nested `$ref`s are broken (https://github.com/ggerganov/llama.cpp/issues/8073)
- [pattern](https://json-schema.org/draft/2020-12/json-schema-validation#name-pattern)s must start with `^` and end with `$`
- Remote `$ref`s not supported in the C++ version (Python & JavaScript versions fetch https refs)
- `string` [formats](https://json-schema.org/draft/2020-12/json-schema-validation#name-defined-formats) lack `uri`, `email`
- No [`patternProperties`](https://json-schema.org/draft/2020-12/json-schema-core#name-patternproperties)

View file

@ -1,10 +0,0 @@
import { readFileSync } from "fs"
import { SchemaConverter } from "../examples/server/public/json-schema-to-grammar.mjs"
// CLI driver: reads the JSON schema file named by the first command-line argument,
// resolves its references relative to a file:// URL for that path, converts it,
// and prints the resulting grammar to stdout.
const schemaPath = process.argv[2]
const baseUrl = `file://${schemaPath}`
const rawSchema = JSON.parse(readFileSync(schemaPath, "utf8"))
const converter = new SchemaConverter({})
// References are resolved before visiting the schema.
const resolvedSchema = await converter.resolveRefs(rawSchema, baseUrl)
converter.visit(resolvedSchema, '')
const grammar = converter.formatGrammar()
console.log(grammar)

View file

@ -1293,6 +1293,46 @@ static void test_json_schema() {
// R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green", "home"]})""",
}
);
test_schema(
"nested refs + mix of properties and allOf",
// Schema
R"""({
"properties": {
"common": {"$ref": "#/$defs/SomeVal"}
},
"allOf": [
{"$ref": "#/$defs/foo"},
{"$ref": "#/$defs/bar"},
{
"anyOf": [
{"$ref": "#/$defs/baz"},
{"$ref": "#/$defs/bam"}
]
}
],
"required": ["common"],
"$defs": {
"SomeVal": {"type": "number"},
"foo": {"properties": {"a": {"$ref": "#/$defs/SomeVal"}}},
"bar": {"properties": {"b": {"$ref": "#/$defs/SomeVal"}}},
"bam": {"properties": {"c": {"$ref": "#/$defs/SomeVal"}}},
"baz": {"properties": {"d": {"$ref": "#/$defs/SomeVal"}}}
}
})""",
// Passing strings
{
R"""({"common": 0, "a": 0, "b": 0})""",
R"""({"common": 0, "a": 0, "b": 0, "d": 0, "c": 0})""",
},
// Failing strings
{
R"""({})""",
R"""({"common": "", "a": "", "b": ""})""",
R"""({"a": 0, "b": 0})""",
R"""({"common": 0, "a": "0", "b": 0, "c": 0, "d": 0})""",
}
);
}
int main() {

View file

@ -346,6 +346,48 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
)"""
});
test({
SUCCESS,
"nested $refs (https://github.com/ggerganov/llama.cpp/issues/8073)",
R"""({
"type": "array",
"minItems": 15,
"maxItems": 15,
"items": { "$ref": "#/$defs/talk" },
"$defs": {
"characters": { "enum": ["Biff", "Alice"] },
"emotes": { "enum": ["EXCLAMATION", "CONFUSION", "CHEERFUL", "LOVE", "ANGRY"] },
"talk": {
"type": "object",
"required": [ "character", "emote", "dialog" ],
"properties": {
"character": { "$ref": "#/$defs/characters" },
"emote": { "$ref": "#/$defs/emotes" },
"dialog": {
"type": "string",
"minLength": 1,
"maxLength": 200
}
},
"additionalProperties": false
}
}
})""",
R"""(
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
characters ::= ("\"Biff\"" | "\"Alice\"") space
emotes ::= ("\"EXCLAMATION\"" | "\"CONFUSION\"" | "\"CHEERFUL\"" | "\"LOVE\"" | "\"ANGRY\"") space
root ::= "[" space talk ("," space talk){14,14} "]" space
space ::= | " " | "\n" [ \t]{0,20}
talk ::= "{" space talk-character-kv "," space talk-emote-kv "," space talk-dialog-kv "}" space
talk-character-kv ::= "\"character\"" space ":" space characters
talk-dialog ::= "\"" char{1,200} "\"" space
talk-dialog-kv ::= "\"dialog\"" space ":" space talk-dialog
talk-emote-kv ::= "\"emote\"" space ":" space emotes
)""",
});
test({
SUCCESS,
"exotic formats",
@ -1105,10 +1147,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
}
})""",
R"""(
a-kv ::= "\"a\"" space ":" space string
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
foo ::= "{" space foo-a-kv "}" space
foo-a-kv ::= "\"a\"" space ":" space string
root ::= foo
root ::= "{" space a-kv "}" space
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
@ -1124,17 +1165,18 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
],
"definitions": {
"foo": {
"properties": {"a": {"type": "number"}}
"properties": {"a": {"type": "number"}},
"additionalProperties": false
},
"bar": {
"properties": {"b": {"type": "number"}}
"properties": {"b": {"type": "number"}},
"additionalProperties": false
}
},
"type": "object"
"type": "object",
"additionalProperties": false
})""",
R"""(
alternative-0 ::= foo
alternative-1 ::= bar
bar ::= "{" space (bar-b-kv )? "}" space
bar-b-kv ::= "\"b\"" space ":" space number
decimal-part ::= [0-9]{1,16}
@ -1142,7 +1184,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
foo-a-kv ::= "\"a\"" space ":" space number
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= alternative-0 | alternative-1
root ::= foo | bar
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -1175,7 +1217,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"properties": {"d": {"type": "number"}}
}
},
"type": "object"
"additionalProperties": false
})""",
R"""(
a-kv ::= "\"a\"" space ":" space number
@ -1191,6 +1233,48 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
)"""
});
test({
SUCCESS,
"nested refs + mix of properties and allOf",
// Schema
R"""({
"properties": {
"common": {"$ref": "#/$defs/SomeVal"}
},
"allOf": [
{"$ref": "#/$defs/foo"},
{"$ref": "#/$defs/bar"},
{
"anyOf": [
{"$ref": "#/$defs/baz"},
{"$ref": "#/$defs/bam"}
]
}
],
"required": ["common"],
"$defs": {
"SomeVal": {"type": "number"},
"foo": {"properties": {"a": {"$ref": "#/$defs/SomeVal"}}},
"bar": {"properties": {"b": {"$ref": "#/$defs/SomeVal"}}},
"bam": {"properties": {"c": {"$ref": "#/$defs/SomeVal"}}},
"baz": {"properties": {"d": {"$ref": "#/$defs/SomeVal"}}}
}
})""",
R"""(
a-kv ::= "\"a\"" space ":" space number
b-kv ::= "\"b\"" space ":" space number
c-kv ::= "\"c\"" space ":" space number
common-kv ::= "\"common\"" space ":" space number
d-kv ::= "\"d\"" space ":" space number
d-rest ::= ( "," space c-kv )?
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space common-kv "," space a-kv "," space b-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
space ::= | " " | "\n" [ \t]{0,20}
)""",
});
test({
SUCCESS,
"conflicting names",
@ -1272,7 +1356,7 @@ int main() {
test_all("JavaScript", [](const TestCase & tc) {
write("test-json-schema-input.tmp", tc.schema);
tc.verify_status(std::system(
"node ./tests/run-json-schema-to-grammar.mjs test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
"node ./examples/json_schema_to_grammar.mjs test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
tc.verify(read("test-grammar-output.tmp"));
});
} else {