Xuan Son Nguyen 2024-06-07 15:24:47 -05:00 committed by GitHub
commit ae93827c7f
2 changed files with 102 additions and 34 deletions


@@ -194,18 +194,15 @@ int32_t cpu_get_num_math() {
void gpt_params_handle_model_default(gpt_params & params) {
if (!params.hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model
if (params.hf_file.empty()) {
if (params.model.empty()) {
throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n");
}
params.hf_file = params.model;
} else if (params.model.empty()) {
params.model_url = llama_get_hf_model_url(params.hf_repo, params.hf_file);
if (params.model.empty()) {
std::string cache_directory = fs_get_cache_directory();
const bool success = fs_create_directory_with_parents(cache_directory);
if (!success) {
throw std::runtime_error("failed to create cache directory: " + cache_directory);
}
params.model = cache_directory + string_split(params.hf_file, '/').back();
// TODO: cache with params.hf_repo in directory
params.model = cache_directory + string_split(params.model_url, '/').back();
}
} else if (!params.model_url.empty()) {
if (params.model.empty()) {
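With this change, when only --hf-repo is given, the local model path is derived from the resolved download URL rather than from --hf-file. A minimal standalone sketch of that derivation (the split_path helper and the cache directory value below are illustrative stand-ins for string_split and fs_get_cache_directory, not part of the diff):

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// illustrative stand-in for string_split(s, '/') used in the code above
static std::vector<std::string> split_path(const std::string & s, char sep) {
    std::vector<std::string> parts;
    std::stringstream ss(s);
    std::string item;
    while (std::getline(ss, item, sep)) {
        parts.push_back(item);
    }
    return parts;
}

int main() {
    std::string model_url       = "https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf";
    std::string cache_directory = "/home/user/.cache/llama.cpp/"; // example value only
    // same derivation as params.model above: cache directory + last URL component
    std::string model = cache_directory + split_path(model_url, '/').back();
    std::cout << model << std::endl; // /home/user/.cache/llama.cpp/ggml-model-f16.gguf
    return 0;
}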
@@ -2279,9 +2276,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
llama_model * model = nullptr;
if (!params.hf_repo.empty() && !params.hf_file.empty()) {
model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
} else if (!params.model_url.empty()) {
if (!params.model_url.empty()) {
model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
} else {
model = llama_load_model_from_file(params.model.c_str(), mparams);
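Note on this hunk: the separate llama_load_model_from_hf branch is gone because a Hugging Face repo is now resolved to a plain download URL earlier, in gpt_params_handle_model_default, so loading needs only two cases here: fetch by URL when params.model_url is set (given directly or derived from --hf-repo), otherwise load params.model from disk.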
@@ -2452,6 +2447,16 @@ static bool starts_with(const std::string & str, const std::string & prefix) {
return str.rfind(prefix, 0) == 0;
}
static bool ends_with(const std::string & str, const std::string & suffix) {
return str.length() >= suffix.length() && str.rfind(suffix) == str.length() - suffix.length();
}
static std::string tolower(std::string s) {
std::transform(s.begin(), s.end(), s.begin(),
[](unsigned char c){ return std::tolower(c); });
return s;
}
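For example, tolower("Q4_K_M.GGUF") yields "q4_k_m.gguf" and ends_with("model-00001-of-00002.gguf", ".gguf") is true; the file-selection code added further down combines these two helpers to pick a .gguf file from the repository listing with a case-insensitive substring match.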
static bool llama_download_file(const std::string & url, const std::string & path) {
// Initialize libcurl
@@ -2732,26 +2737,91 @@ struct llama_model * llama_load_model_from_url(
return llama_load_model_from_file(path_model, params);
}
struct llama_model * llama_load_model_from_hf(
const char * repo,
const char * model,
const char * path_model,
const struct llama_model_params & params) {
// construct hugging face model url:
//
// --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
// https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
//
// --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
// https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
//
static std::string llama_get_hf_model_url(
std::string & repo,
std::string & custom_file_path) {
std::stringstream ss;
json repo_files;
std::string model_url = "https://huggingface.co/";
model_url += repo;
model_url += "/resolve/main/";
model_url += model;
if (!custom_file_path.empty()) {
ss << "https://huggingface.co/" << repo << "/resolve/main/" << custom_file_path;
return ss.str();
}
return llama_load_model_from_url(model_url.c_str(), path_model, params);
{
// Initialize libcurl
std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
// Make the request to Hub API
ss << "https://huggingface.co/api/models/" << repo << "/tree/main?recursive=true";
std::string url = ss.str();
std::string res_str;
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
static_cast<std::string *>(data)->append((char *) ptr, size * nmemb);
return size * nmemb;
};
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
#if defined(_WIN32)
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
fprintf(stderr, "%s: cannot make GET request to Hugging Face Hub API\n", __func__);
return ""; // return an empty string on failure (return type is std::string)
}
long res_code;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
if (res_code != 200) {
fprintf(stderr, "%s: Hugging Face Hub API responses with status code %ld\n", __func__, res_code);
return nullptr;
} else {
repo_files = json::parse(res_str);
}
}
if (!repo_files.is_array()) {
fprintf(stderr, "%s: response from Hugging Face Hub API is not an array\nRaw response:\n%s", __func__, repo_files.dump(4).c_str());
return nullptr;
}
auto get_file_contains = [&](std::string piece) -> std::string {
for (auto elem : repo_files) {
std::string type = elem.at("type");
std::string path = elem.at("path");
if (
type == "file"
&& ends_with(path, ".gguf")
&& tolower(path).find(piece) != std::string::npos
) return path;
}
return "";
};
std::string file_path = get_file_contains("q4_k_m");
if (file_path.empty()) {
file_path = get_file_contains("q4");
}
if (file_path.empty()) {
file_path = get_file_contains("00001");
}
if (file_path.empty()) {
file_path = get_file_contains("gguf");
}
if (file_path.empty()) {
fprintf(stderr, "%s: Cannot find any gguf file in the given repository", __func__);
return nullptr;
}
ss = std::stringstream();
ss << "https://huggingface.co/" << repo << "/resolve/main/" << file_path;
return ss.str();
}
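The preference order implemented above, q4_k_m first, then q4, then 00001 (the first shard of a split model), then any gguf, can be illustrated with a small standalone sketch over a hypothetical repository listing (the helper names and the file list below are illustrative, not part of the diff):

#include <algorithm>
#include <cctype>
#include <iostream>
#include <string>
#include <vector>

// stand-ins mirroring the ends_with/tolower helpers added in this commit
static bool has_suffix(const std::string & s, const std::string & suffix) {
    return s.size() >= suffix.size() && s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}
static std::string lowercase(std::string s) {
    std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c){ return std::tolower(c); });
    return s;
}

int main() {
    // hypothetical repository listing (file paths only)
    std::vector<std::string> paths = {
        "README.md",
        "mixtral-8x7b-v0.1.Q8_0.gguf",
        "mixtral-8x7b-v0.1.Q4_K_M.gguf",
    };
    std::string picked;
    // same preference order as llama_get_hf_model_url
    for (const char * piece : {"q4_k_m", "q4", "00001", "gguf"}) {
        for (const std::string & path : paths) {
            if (has_suffix(path, ".gguf") && lowercase(path).find(piece) != std::string::npos) {
                picked = path;
                break;
            }
        }
        if (!picked.empty()) {
            break;
        }
    }
    std::cout << picked << std::endl; // mixtral-8x7b-v0.1.Q4_K_M.gguf
    return 0;
}

The selected path is then appended to https://huggingface.co/<repo>/resolve/main/, exactly as in the comment block near the top of this hunk.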
#else
@@ -2764,11 +2834,9 @@ struct llama_model * llama_load_model_from_url(
return nullptr;
}
struct llama_model * llama_load_model_from_hf(
const char * /*repo*/,
const char * /*model*/,
const char * /*path_model*/,
const struct llama_model_params & /*params*/) {
static std::string llama_get_hf_model_url(
std::string & /*repo*/,
std::string & /*custom_file_path*/) {
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
return "";
}


@@ -287,7 +287,7 @@ struct llama_model_params llama_model_params_from_gpt_params (const gpt_param
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);
static std::string llama_get_hf_model_url(std::string & repo, std::string & custom_file_path);
// Batch utils