json: handle external $refs in C++ schema->grammar converter using libcurl
This commit is contained in:
parent
d63c953185
commit
9026acbdf2
1 changed files with 48 additions and 3 deletions
|
@ -8,7 +8,11 @@
|
|||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
#if defined(LLAMA_USE_CURL)
|
||||
#include <curl/curl.h>
|
||||
#include <curl/easy.h>
|
||||
#endif
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
|
@ -817,7 +821,7 @@ public:
|
|||
|
||||
ResolvedRef _resolve_ref(const std::string & ref) {
|
||||
auto parts = split(ref, "#");
|
||||
if (parts.size() != 2) {
|
||||
if (parts.size() > 2) {
|
||||
_errors.push_back("Unsupported ref: " + ref);
|
||||
return {json(), "", false};
|
||||
}
|
||||
|
@ -840,6 +844,9 @@ public:
|
|||
_external_refs[url] = target;
|
||||
}
|
||||
}
|
||||
if (parts.size() == 1) {
|
||||
return {target, "", is_local};
|
||||
}
|
||||
auto tokens = split(parts[1], "/");
|
||||
for (size_t i = 1; i < tokens.size(); ++i) {
|
||||
const auto & sel = tokens[i];
|
||||
|
@ -1040,8 +1047,46 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
static size_t json_schema_ref_curl_write_callback(char *ptr, size_t size, size_t nmemb, void *data) {
|
||||
auto &response = *static_cast<std::ostringstream *>(data);
|
||||
response.write((char *)ptr, size * nmemb);
|
||||
return size * nmemb;
|
||||
}
|
||||
|
||||
std::string json_schema_to_grammar(const json & schema) {
|
||||
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
|
||||
std::function<json(const std::string &)> fetch_json = [](const std::string &) { return json::object(); };
|
||||
|
||||
#if defined(LLAMA_USE_CURL)
|
||||
fetch_json = [](const std::string & url) {
|
||||
std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
|
||||
if (!curl) {
|
||||
fprintf(stderr, "%s: error initializing libcurl\n", __func__);
|
||||
return json::object();
|
||||
}
|
||||
|
||||
std::ostringstream response;
|
||||
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
|
||||
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
|
||||
|
||||
#if defined(_WIN32)
|
||||
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
|
||||
// operating system. Currently implemented under MS-Windows.
|
||||
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
|
||||
#endif
|
||||
|
||||
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
|
||||
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, &json_schema_ref_curl_write_callback);
|
||||
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response);
|
||||
CURLcode res = curl_easy_perform(curl.get());
|
||||
if (res != CURLE_OK) {
|
||||
throw std::runtime_error("Failed to fetch " + url + ": " + curl_easy_strerror(res));
|
||||
}
|
||||
response << '\0';
|
||||
return json::parse(response.str());
|
||||
};
|
||||
#endif
|
||||
|
||||
SchemaConverter converter(fetch_json, /* dotall= */ false);
|
||||
auto copy = schema;
|
||||
converter.visit(copy, "");
|
||||
converter.check_errors();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue