diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 3e2a9772b..823db93a6 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -8,7 +8,11 @@ #include #include #include -#include + +#if defined(LLAMA_USE_CURL) +#include +#include +#endif using json = nlohmann::ordered_json; @@ -817,7 +821,7 @@ public: ResolvedRef _resolve_ref(const std::string & ref) { auto parts = split(ref, "#"); - if (parts.size() != 2) { + if (parts.size() > 2) { _errors.push_back("Unsupported ref: " + ref); return {json(), "", false}; } @@ -840,6 +844,9 @@ public: _external_refs[url] = target; } } + if (parts.size() == 1) { + return {target, "", is_local}; + } auto tokens = split(parts[1], "/"); for (size_t i = 1; i < tokens.size(); ++i) { const auto & sel = tokens[i]; @@ -1040,8 +1047,46 @@ public: } }; +static size_t json_schema_ref_curl_write_callback(char *ptr, size_t size, size_t nmemb, void *data) { + auto &response = *static_cast(data); + response.write((char *)ptr, size * nmemb); + return size * nmemb; +} + std::string json_schema_to_grammar(const json & schema) { - SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false); + std::function fetch_json = [](const std::string &) { return json::object(); }; + +#if defined(LLAMA_USE_CURL) + fetch_json = [](const std::string & url) { + std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); + if (!curl) { + fprintf(stderr, "%s: error initializing libcurl\n", __func__); + return json::object(); + } + + std::ostringstream response; + curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); + +#if defined(_WIN32) + // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of + // operating system. Currently implemented under MS-Windows. + curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); +#endif + + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, &json_schema_ref_curl_write_callback); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response); + CURLcode res = curl_easy_perform(curl.get()); + if (res != CURLE_OK) { + throw std::runtime_error("Failed to fetch " + url + ": " + curl_easy_strerror(res)); + } + response << '\0'; + return json::parse(response.str()); + }; +#endif + + SchemaConverter converter(fetch_json, /* dotall= */ false); auto copy = schema; converter.visit(copy, ""); converter.check_errors();