From 803031665afc6d0a7d7391693f045c4d0051eb95 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 11 Jan 2025 19:44:12 +0100 Subject: [PATCH] common : support tag-based hf_repo like on ollama --- common/arg.cpp | 115 ++++++++++++++++++++++++++++++++++++++++++---- common/common.cpp | 9 ++-- common/common.h | 5 ++ 3 files changed, 114 insertions(+), 15 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 27886b84e..112a0dc3e 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -13,6 +13,12 @@ #include #include +#if defined(LLAMA_USE_CURL) +#include +#include +#include +#endif + #include "json-schema-to-grammar.h" using json = nlohmann::ordered_json; @@ -128,18 +134,105 @@ std::string common_arg::to_string() { // utils // +#if defined(LLAMA_USE_CURL) +/** + * Allow getting the HF file from the HF repo with tag (like ollama), for example: + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 + * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s + * Tag is optional, default to Q4_K_M if it exists + * Return pair of (with "repo" already having tag removed) + */ +static std::pair common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) { + auto parts = string_split(hf_repo_with_tag, ':'); + std::string tag = parts.size() > 1 ? parts[1] : "latest"; // "latest" means checking Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo + std::string hf_repo = parts[0]; + if (string_split(hf_repo, '/').size() != 2) { + throw std::invalid_argument("error: invalid HF repo format, expected /[:tag]\n"); + } + + // fetch model info from Hugging Face Hub API + json model_info; + std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); + std::unique_ptr http_headers(nullptr, &curl_slist_free_all); + std::string res_str; + std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag; + curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); + typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); + auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { + static_cast(data)->append((char * ) ptr, size * nmemb); + return size * nmemb; + }; + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str); +#if defined(_WIN32) + curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); +#endif + if (!hf_token.empty()) { + std::string auth_header = "Authorization: Bearer " + hf_token; + http_headers.reset(curl_slist_append(http_headers.get(), auth_header.c_str())); + // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response + http_headers.reset(curl_slist_append(http_headers.get(), "User-Agent: llama-cpp")); + http_headers.reset(curl_slist_append(http_headers.get(), "Accept: application/json")); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.get()); + } + CURLcode res = curl_easy_perform(curl.get()); + + if (res != CURLE_OK) { + throw std::runtime_error("error: cannot make GET request to Hugging Face Hub API"); + } + + long res_code; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); + if (res_code == 200) { + model_info = json::parse(res_str); + } if (res_code == 401) { + throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); + } else { + throw std::runtime_error(string_format("error: cannot get model info from Hugging Face Hub API, response code: %ld", res_code)); + } + + // check response + if (!model_info.contains("ggufFile")) { + throw std::runtime_error("error: model does not have ggufFile"); + } + json & gguf_file = model_info.at("ggufFile"); + if (!gguf_file.contains("rfilename")) { + throw std::runtime_error("error: ggufFile does not have rfilename"); + } + + // TODO handle error + return std::make_pair(hf_repo, gguf_file.at("rfilename")); +} +#else +static std::string common_get_hf_file(const std::string &, const std::string &) { + throw std::runtime_error("error: llama.cpp built without libcurl"); +} +#endif + static void common_params_handle_model_default( std::string & model, - std::string & model_url, + const std::string & model_url, std::string & hf_repo, - std::string & hf_file) { + std::string & hf_file, + const std::string & hf_token) { if (!hf_repo.empty()) { // short-hand to avoid specifying --hf-file -> default it to --model if (hf_file.empty()) { if (model.empty()) { - throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n"); + try { + auto auto_detected = common_get_hf_file(hf_repo, hf_token); + hf_repo = auto_detected.first; + hf_file = auto_detected.second; + printf("%s: using hf_file = %s\n", __func__, hf_file.c_str()); + } catch (std::exception & e) { + fprintf(stderr, "%s: %s\n", __func__, e.what()); + exit(1); + } + } else { + hf_file = model; } - hf_file = model; } else if (model.empty()) { // this is to avoid different repo having same file name, or same file name in different subdirs std::string filename = hf_repo + "_" + hf_file; @@ -290,8 +383,8 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context } // TODO: refactor model params in a common struct - common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file); - common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file); + common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token); + common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token); if (params.escape) { string_process_escapes(params.prompt); @@ -1583,21 +1676,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_env("LLAMA_ARG_MODEL_URL")); add_opt(common_arg( - {"-hfr", "--hf-repo"}, "REPO", - "Hugging Face model repository (default: unused)", + {"-hf", "-hfr", "--hf-repo"}, "/[:quant]", + "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n" + "example: unsloth/phi-4-GGUF:q4_k_m\n" + "(default: unused)", [](common_params & params, const std::string & value) { params.hf_repo = value; } ).set_env("LLAMA_ARG_HF_REPO")); add_opt(common_arg( {"-hff", "--hf-file"}, "FILE", - "Hugging Face model file (default: unused)", + "Hugging Face model file, unused if quant is already specified in --hf-repo (default: unused)", [](common_params & params, const std::string & value) { params.hf_file = value; } ).set_env("LLAMA_ARG_HF_FILE")); add_opt(common_arg( - {"-hfrv", "--hf-repo-v"}, "REPO", + {"-hfv", "-hfrv", "--hf-repo-v"}, "/[:quant]", "Hugging Face model repository for the vocoder model (default: unused)", [](common_params & params, const std::string & value) { params.vocoder.hf_repo = value; diff --git a/common/common.cpp b/common/common.cpp index 86e4e1e24..dca7ddf69 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1127,6 +1127,7 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { // Initialize libcurl std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); + std::unique_ptr http_headers(nullptr, &curl_slist_free_all); if (!curl) { LOG_ERR("%s: error initializing libcurl\n", __func__); return false; @@ -1140,11 +1141,9 @@ static bool common_download_file(const std::string & url, const std::string & pa // Check if hf-token or bearer-token was specified if (!hf_token.empty()) { - std::string auth_header = "Authorization: Bearer "; - auth_header += hf_token.c_str(); - struct curl_slist *http_headers = NULL; - http_headers = curl_slist_append(http_headers, auth_header.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers); + std::string auth_header = "Authorization: Bearer " + hf_token; + http_headers.reset(curl_slist_append(http_headers.get(), auth_header.c_str())); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.get()); } #if defined(_WIN32) diff --git a/common/common.h b/common/common.h index 0d452cf0f..42d75ef4b 100644 --- a/common/common.h +++ b/common/common.h @@ -454,6 +454,11 @@ static bool string_starts_with(const std::string & str, return str.rfind(prefix, 0) == 0; } +static bool string_ends_with(const std::string & str, + const std::string & suffix) { // While we wait for C++20's std::string::ends_with... + return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; +} + bool string_parse_kv_override(const char * data, std::vector & overrides); void string_process_escapes(std::string & input);