Update json-schema-to-grammar.cpp

This commit is contained in:
ochafik 2024-03-12 02:06:48 +00:00
parent 8caaf1641d
commit 8fee84b45c

View file

@ -1,13 +1,14 @@
#include "json-schema-to-grammar.h" #include "json-schema-to-grammar.h"
#include <algorithm>
#include <fstream>
#include <iostream>
#include <map>
#include <regex>
#include <sstream>
#include <string> #include <string>
#include <vector>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <iostream> #include <vector>
#include <regex>
#include <algorithm>
#include <sstream>
#include <fstream>
using json = nlohmann::json; using json = nlohmann::json;
using namespace std; using namespace std;
@ -21,8 +22,15 @@ unordered_map<string, string> PRIMITIVE_RULES = {
{"value", "object | array | string | number | boolean"}, {"value", "object | array | string | number | boolean"},
{"object", "\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space"}, {"object", "\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space"},
{"array", "\"[\" space ( value (\",\" space value)* )? \"]\" space"}, {"array", "\"[\" space ( value (\",\" space value)* )? \"]\" space"},
{"uuid", "\"\\\"\" \"-\" \"-\" \"-\" \"-\" \"\\\"\" space"}, {"uuid", "\"\\\"\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
{"string", "\"\\\"\" ([^\"\\\\] | \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]))* \"\\\"\" space"}, "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] \"\\\"\" space"},
{"string", " \"\\\"\" (\n"
" [^\"\\\\] |\n"
" \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])\n"
" )* \"\\\"\" space"},
{"null", "\"null\" space"} {"null", "\"null\" space"}
}; };
@ -61,15 +69,15 @@ static std::vector<std::string> split(const std::string& str, const std::string&
std::vector<std::string> tokens; std::vector<std::string> tokens;
size_t start = 0; size_t start = 0;
size_t end = str.find(delimiter); size_t end = str.find(delimiter);
while (end != std::string::npos) { while (end != std::string::npos) {
tokens.push_back(str.substr(start, end - start)); tokens.push_back(str.substr(start, end - start));
start = end + delimiter.length(); start = end + delimiter.length();
end = str.find(delimiter, start); end = str.find(delimiter, start);
} }
tokens.push_back(str.substr(start)); tokens.push_back(str.substr(start));
return tokens; return tokens;
} }
@ -77,32 +85,32 @@ static string repeat(const string& str, size_t n) {
if (n == 0) { if (n == 0) {
return ""; return "";
} }
string result; string result;
result.reserve(str.length() * n); result.reserve(str.length() * n);
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
result += str; result += str;
} }
return result; return result;
} }
static std::string replacePattern(const std::string& input, const regex& regex, const function<string(const smatch &)>& replacement) { static std::string replacePattern(const std::string& input, const regex& regex, const function<string(const smatch &)>& replacement) {
std::smatch match; std::smatch match;
std::string result; std::string result;
std::string::const_iterator searchStart(input.cbegin()); std::string::const_iterator searchStart(input.cbegin());
std::string::const_iterator searchEnd(input.cend()); std::string::const_iterator searchEnd(input.cend());
while (std::regex_search(searchStart, searchEnd, match, regex)) { while (std::regex_search(searchStart, searchEnd, match, regex)) {
result.append(searchStart, searchStart + match.position()); result.append(searchStart, searchStart + match.position());
result.append(replacement(match)); result.append(replacement(match));
searchStart = match.suffix().first; searchStart = match.suffix().first;
} }
result.append(searchStart, searchEnd); result.append(searchStart, searchEnd);
return result; return result;
} }
@ -126,12 +134,12 @@ class SchemaConverter {
private: private:
std::optional<std::function<json(const string&)>> _fetch_json; std::optional<std::function<json(const string&)>> _fetch_json;
bool _dotall; bool _dotall;
unordered_map<string, string> _rules; map<string, string> _rules;
unordered_map<string, nlohmann::json> _refs; unordered_map<string, nlohmann::json> _refs;
unordered_set<string> _refs_being_resolved; unordered_set<string> _refs_being_resolved;
vector<string> _errors; vector<string> _errors;
vector<string> _warnings; vector<string> _warnings;
string _add_rule(const string& name, const string& rule) { string _add_rule(const string& name, const string& rule) {
string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-"); string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-");
if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) { if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) {
@ -289,26 +297,34 @@ private:
seq.back().first = sub + "+"; seq.back().first = sub + "+";
} else { } else {
if (!sub_is_literal) { if (!sub_is_literal) {
string sub_id = sub_rule_ids[sub]; string& sub_id = sub_rule_ids[sub];
if (sub_id.empty()) { if (sub_id.empty()) {
sub_id = _add_rule(name + "-" + to_string(sub_rule_ids.size() + 1), sub); sub_id = _add_rule(name + "-" + to_string(sub_rule_ids.size()), sub);
sub_rule_ids[sub] = sub_id;
} }
sub = sub_id; sub = sub_id;
} }
string result; string result;
for (int j = 0; j < min_times; j++) { if (sub_is_literal && min_times > 0) {
if (sub_is_literal) { result = "\"" + repeat(sub.substr(1, sub.length() - 2), min_times) + "\"";
result += "\"" + repeat(sub.substr(1, sub.length() - 2), min_times) + "\""; } else {
} else { for (int j = 0; j < min_times; j++) {
result += sub + " "; if (j > 0) {
result += " ";
}
result += sub;
} }
} }
if (min_times > 0 && min_times < max_times) {
result += " ";
}
if (max_times == numeric_limits<int>::max()) { if (max_times == numeric_limits<int>::max()) {
result += sub + "*"; result += sub + "*";
} else { } else {
for (int j = min_times; j < max_times; j++) { for (int j = min_times; j < max_times; j++) {
result += sub + "? "; if (j > min_times) {
result += " ";
}
result += sub + "?";
} }
} }
seq.back().first = result; seq.back().first = result;
@ -417,12 +433,10 @@ private:
} }
rule += get_recursive_refs(vector<string>(optional_props.begin() + i, optional_props.end()), false); rule += get_recursive_refs(vector<string>(optional_props.begin() + i, optional_props.end()), false);
} }
rule += " ";
if (!required_props.empty()) { if (!required_props.empty()) {
rule += " ) "; rule += " )";
} }
rule += " )? "; rule += " )?";
} }
rule += " \"}\" space "; rule += " \"}\" space ";
@ -595,30 +609,25 @@ public:
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json(); json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : -1; int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : -1;
if (min_items > 0) { if (min_items > 0) {
successive_items = list_item_operator; successive_items += repeat(list_item_operator, min_items - 1);
for (int i = 1; i < min_items; i++) {
successive_items += list_item_operator;
}
min_items--; min_items--;
} }
if (max_items >= 0 && max_items > min_items) { if (max_items >= 0 && max_items > min_items) {
for (int i = min_items; i < max_items - 1; i++) { successive_items += repeat(list_item_operator + "?", max_items - min_items - 1);
successive_items += (list_item_operator + "?");
}
} else { } else {
successive_items += list_item_operator + "*"; successive_items += list_item_operator + "*";
} }
string rule; string rule;
if (min_items == 0) { if (min_items == 0) {
rule = "\"[\" space ( " + item_rule_name + " " + successive_items + " )? \"]\" space"; rule = "\"[\" space ( " + item_rule_name + " " + successive_items + " )? \"]\" space";
} else { } else {
rule = "\"[\" space " + item_rule_name + " " + successive_items + " \"]\" space"; rule = "\"[\" space " + item_rule_name + " " + successive_items + " \"]\" space";
} }
return _add_rule(rule_name, rule); return _add_rule(rule_name, rule);
} }
} else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) { } else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
return _visit_pattern(schema["pattern"], rule_name); return _visit_pattern(schema["pattern"], rule_name);
} else if ((schema_type == "object" || schema_type.is_null()) && (schema.size() == 1 || schema.empty())) { } else if (schema.empty() || (schema.size() == 1 && schema_type == "object")) {
for (const auto& [t, r] : PRIMITIVE_RULES) { for (const auto& [t, r] : PRIMITIVE_RULES) {
_add_rule(t, r); _add_rule(t, r);
} }
@ -667,3 +676,23 @@ string json_schema_to_grammar(const json& schema) {
converter.check_errors(); converter.check_errors();
return converter.format_grammar(); return converter.format_grammar();
} }
#ifdef LLAMA_BUILD_JSON_SCHEMA_CONVERTER
int main(int argc, const char** argv) {
if (argc != 2) {
cerr << "Expected only one argument" << endl;
return -1;
}
string file(argv[1]);
string schema;
if (file == "-") {
schema.append(istreambuf_iterator<char>(cin), istreambuf_iterator<char>());
} else {
ifstream in(argv[1]);
schema.append(istreambuf_iterator<char>(in), istreambuf_iterator<char>());
}
cout << json_schema_to_grammar(json::parse(schema)).c_str() << endl;
}
#endif