From 8fee84b45cfeac573a49733ca604c1a1d7a9fea4 Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 12 Mar 2024 02:06:48 +0000 Subject: [PATCH] Update json-schema-to-grammar.cpp --- examples/server/json-schema-to-grammar.cpp | 115 +++++++++++++-------- 1 file changed, 72 insertions(+), 43 deletions(-) diff --git a/examples/server/json-schema-to-grammar.cpp b/examples/server/json-schema-to-grammar.cpp index 36ef5d6e6..31f321b5d 100644 --- a/examples/server/json-schema-to-grammar.cpp +++ b/examples/server/json-schema-to-grammar.cpp @@ -1,13 +1,14 @@ #include "json-schema-to-grammar.h" +#include +#include +#include +#include +#include +#include #include -#include #include #include -#include -#include -#include -#include -#include +#include using json = nlohmann::json; using namespace std; @@ -21,8 +22,15 @@ unordered_map PRIMITIVE_RULES = { {"value", "object | array | string | number | boolean"}, {"object", "\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space"}, {"array", "\"[\" space ( value (\",\" space value)* )? \"]\" space"}, - {"uuid", "\"\\\"\" \"-\" \"-\" \"-\" \"-\" \"\\\"\" space"}, - {"string", "\"\\\"\" ([^\"\\\\] | \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]))* \"\\\"\" space"}, + {"uuid", "\"\\\"\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " + "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " + "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " + "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " + "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] \"\\\"\" space"}, + {"string", " \"\\\"\" (\n" + " [^\"\\\\] |\n" + " \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])\n" + " )* \"\\\"\" space"}, {"null", "\"null\" space"} }; @@ -61,15 +69,15 @@ static std::vector split(const std::string& str, const std::string& std::vector tokens; size_t start = 0; size_t end = str.find(delimiter); - + while (end != std::string::npos) { tokens.push_back(str.substr(start, end - start)); start = end + delimiter.length(); end = str.find(delimiter, start); } - + tokens.push_back(str.substr(start)); - + return tokens; } @@ -77,32 +85,32 @@ static string repeat(const string& str, size_t n) { if (n == 0) { return ""; } - + string result; result.reserve(str.length() * n); - + for (size_t i = 0; i < n; ++i) { result += str; } - + return result; } static std::string replacePattern(const std::string& input, const regex& regex, const function& replacement) { std::smatch match; std::string result; - + std::string::const_iterator searchStart(input.cbegin()); std::string::const_iterator searchEnd(input.cend()); - + while (std::regex_search(searchStart, searchEnd, match, regex)) { result.append(searchStart, searchStart + match.position()); result.append(replacement(match)); searchStart = match.suffix().first; } - + result.append(searchStart, searchEnd); - + return result; } @@ -126,12 +134,12 @@ class SchemaConverter { private: std::optional> _fetch_json; bool _dotall; - unordered_map _rules; + map _rules; unordered_map _refs; unordered_set _refs_being_resolved; vector _errors; vector _warnings; - + string _add_rule(const string& name, const string& rule) { string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-"); if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) { @@ -289,26 +297,34 @@ private: seq.back().first = sub + "+"; } else { if (!sub_is_literal) { - string sub_id = sub_rule_ids[sub]; + string& sub_id = sub_rule_ids[sub]; if (sub_id.empty()) { - sub_id = _add_rule(name + "-" + to_string(sub_rule_ids.size() + 1), sub); - sub_rule_ids[sub] = sub_id; + sub_id = _add_rule(name + "-" + to_string(sub_rule_ids.size()), sub); } sub = sub_id; } string result; - for (int j = 0; j < min_times; j++) { - if (sub_is_literal) { - result += "\"" + repeat(sub.substr(1, sub.length() - 2), min_times) + "\""; - } else { - result += sub + " "; + if (sub_is_literal && min_times > 0) { + result = "\"" + repeat(sub.substr(1, sub.length() - 2), min_times) + "\""; + } else { + for (int j = 0; j < min_times; j++) { + if (j > 0) { + result += " "; + } + result += sub; } } + if (min_times > 0 && min_times < max_times) { + result += " "; + } if (max_times == numeric_limits::max()) { result += sub + "*"; } else { for (int j = min_times; j < max_times; j++) { - result += sub + "? "; + if (j > min_times) { + result += " "; + } + result += sub + "?"; } } seq.back().first = result; @@ -417,12 +433,10 @@ private: } rule += get_recursive_refs(vector(optional_props.begin() + i, optional_props.end()), false); } - rule += " "; - if (!required_props.empty()) { - rule += " ) "; + rule += " )"; } - rule += " )? "; + rule += " )?"; } rule += " \"}\" space "; @@ -595,30 +609,25 @@ public: json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json(); int max_items = max_items_json.is_number_integer() ? max_items_json.get() : -1; if (min_items > 0) { - successive_items = list_item_operator; - for (int i = 1; i < min_items; i++) { - successive_items += list_item_operator; - } + successive_items += repeat(list_item_operator, min_items - 1); min_items--; } if (max_items >= 0 && max_items > min_items) { - for (int i = min_items; i < max_items - 1; i++) { - successive_items += (list_item_operator + "?"); - } + successive_items += repeat(list_item_operator + "?", max_items - min_items - 1); } else { successive_items += list_item_operator + "*"; } string rule; if (min_items == 0) { - rule = "\"[\" space ( " + item_rule_name + " " + successive_items + " )? \"]\" space"; + rule = "\"[\" space ( " + item_rule_name + " " + successive_items + " )? \"]\" space"; } else { - rule = "\"[\" space " + item_rule_name + " " + successive_items + " \"]\" space"; + rule = "\"[\" space " + item_rule_name + " " + successive_items + " \"]\" space"; } return _add_rule(rule_name, rule); } } else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) { return _visit_pattern(schema["pattern"], rule_name); - } else if ((schema_type == "object" || schema_type.is_null()) && (schema.size() == 1 || schema.empty())) { + } else if (schema.empty() || (schema.size() == 1 && schema_type == "object")) { for (const auto& [t, r] : PRIMITIVE_RULES) { _add_rule(t, r); } @@ -667,3 +676,23 @@ string json_schema_to_grammar(const json& schema) { converter.check_errors(); return converter.format_grammar(); } + +#ifdef LLAMA_BUILD_JSON_SCHEMA_CONVERTER + +int main(int argc, const char** argv) { + if (argc != 2) { + cerr << "Expected only one argument" << endl; + return -1; + } + string file(argv[1]); + string schema; + if (file == "-") { + schema.append(istreambuf_iterator(cin), istreambuf_iterator()); + } else { + ifstream in(argv[1]); + schema.append(istreambuf_iterator(in), istreambuf_iterator()); + } + cout << json_schema_to_grammar(json::parse(schema)).c_str() << endl; +} + +#endif