diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 61882f319..98392828d 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -392,6 +392,12 @@ static std::string format_literal(const std::string & literal) { return "\"" + escaped + "\""; } +static size_t json_schema_ref_curl_write_callback(char *ptr, size_t size, size_t nmemb, void *data) { + auto &response = *static_cast(data); + response.write((char *)ptr, size * nmemb); + return size * nmemb; +} + class SchemaConverter { private: std::function _fetch_json; @@ -801,12 +807,50 @@ private: } public: - SchemaConverter( - const std::function & fetch_json, - bool dotall) - : _fetch_json(fetch_json), _dotall(dotall) + SchemaConverter(bool dotall) : _dotall(dotall) { _rules["space"] = SPACE_RULE; + +#if defined(LLAMA_USE_CURL) + _fetch_json = [&](const std::string & url) { + // TODO: implement HTTP caching semantics. + static std::unordered_map cache; + auto it = cache.find(url); + if (it != cache.end()) { + return it->second; + } + std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); + if (!curl) { + fprintf(stderr, "%s: error initializing libcurl\n", __func__); + return json::object(); + } + + std::ostringstream response; + curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); + + #if defined(_WIN32) + // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of + // operating system. Currently implemented under MS-Windows. + curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); + #endif + + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, &json_schema_ref_curl_write_callback); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response); + CURLcode res = curl_easy_perform(curl.get()); + if (res != CURLE_OK) { + throw std::runtime_error("Failed to fetch " + url + ": " + curl_easy_strerror(res)); + } + response << '\0'; + return cache[url] = json::parse(response.str()); + }; +#else + fetch_json = [](const std::string &) { + _errors.push_back("Fetching external refs not supported, please recompile with CURL support."); + return json::object(); + }; +#endif } std::string _generate_constant_rule(const json & value) { @@ -1001,7 +1045,7 @@ public: } } } else { - // todo warning + _warnings.push_back("Unsupported allOf schema"); } }; for (auto & t : schema["allOf"]) { @@ -1050,53 +1094,8 @@ public: } }; -static size_t json_schema_ref_curl_write_callback(char *ptr, size_t size, size_t nmemb, void *data) { - auto &response = *static_cast(data); - response.write((char *)ptr, size * nmemb); - return size * nmemb; -} - std::string json_schema_to_grammar(const json & schema) { - std::function fetch_json = [](const std::string &) { return json::object(); }; - -#if defined(LLAMA_USE_CURL) - // TODO: implement HTTP caching semantics. - std::unordered_map cache; - - fetch_json = [&](const std::string & url) { - auto it = cache.find(url); - if (it != cache.end()) { - return it->second; - } - std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); - if (!curl) { - fprintf(stderr, "%s: error initializing libcurl\n", __func__); - return json::object(); - } - - std::ostringstream response; - curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); - -#if defined(_WIN32) - // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of - // operating system. Currently implemented under MS-Windows. - curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); -#endif - - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress - curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, &json_schema_ref_curl_write_callback); - curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response); - CURLcode res = curl_easy_perform(curl.get()); - if (res != CURLE_OK) { - throw std::runtime_error("Failed to fetch " + url + ": " + curl_easy_strerror(res)); - } - response << '\0'; - return cache[url] = json::parse(response.str()); - }; -#endif - - SchemaConverter converter(fetch_json, /* dotall= */ false); + SchemaConverter converter(/* dotall= */ false); auto copy = schema; converter.visit(copy, ""); converter.check_errors();