--hf-repo without --hf-file
parent 74f33adf5f
commit 8afc0f3784
2 changed files with 102 additions and 34 deletions
common/common.cpp

@@ -193,18 +193,15 @@ int32_t cpu_get_num_math() {
 void gpt_params_handle_model_default(gpt_params & params) {
     if (!params.hf_repo.empty()) {
-        // short-hand to avoid specifying --hf-file -> default it to --model
-        if (params.hf_file.empty()) {
-            if (params.model.empty()) {
-                throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n");
-            }
-            params.hf_file = params.model;
-        } else if (params.model.empty()) {
+        params.model_url = llama_get_hf_model_url(params.hf_repo, params.hf_file);
+        if (params.model.empty()) {
             std::string cache_directory = fs_get_cache_directory();
             const bool success = fs_create_directory_with_parents(cache_directory);
             if (!success) {
                 throw std::runtime_error("failed to create cache directory: " + cache_directory);
             }
-            params.model = cache_directory + string_split(params.hf_file, '/').back();
+            // TODO: cache with params.hf_repo in directory
+            params.model = cache_directory + string_split(params.model_url, '/').back();
         }
     } else if (!params.model_url.empty()) {
         if (params.model.empty()) {
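Note: after this hunk, --hf-file is optional; the repo/file pair is resolved to params.model_url up front, and when --model is not given it defaults to the URL's last path component inside the cache directory (hence the TODO about per-repo caching, since two repos can ship files with the same name). A minimal standalone sketch of that derivation, with local_split standing in for common.cpp's string_split helper and a made-up cache path:

    #include <sstream>
    #include <string>
    #include <vector>

    // stand-in for string_split(str, delim) from common.cpp
    static std::vector<std::string> local_split(const std::string & s, char delim) {
        std::vector<std::string> parts;
        std::stringstream ss(s);
        std::string item;
        while (std::getline(ss, item, delim)) {
            parts.push_back(item);
        }
        return parts;
    }

    int main() {
        const std::string model_url = "https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf";
        const std::string cache_directory = "/home/user/.cache/llama.cpp/"; // assumed to end with a separator
        const std::string model = cache_directory + local_split(model_url, '/').back();
        // model == "/home/user/.cache/llama.cpp/ggml-model-f16.gguf"
        return 0;
    }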
@@ -1888,9 +1885,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(
     llama_model * model = nullptr;

-    if (!params.hf_repo.empty() && !params.hf_file.empty()) {
-        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
-    } else if (!params.model_url.empty()) {
+    if (!params.model_url.empty()) {
         model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
     } else {
         model = llama_load_model_from_file(params.model.c_str(), mparams);
@@ -2061,6 +2056,16 @@ static bool starts_with(const std::string & str, const std::string & prefix) {
     return str.rfind(prefix, 0) == 0;
 }

+static bool ends_with(const std::string & str, const std::string & suffix) {
+    return str.rfind(suffix) == str.length() - suffix.length();
+}
+
+static std::string tolower(std::string s) {
+    std::transform(s.begin(), s.end(), s.begin(),
+                   [](unsigned char c){ return std::tolower(c); });
+    return s;
+}
+
 static bool llama_download_file(const std::string & url, const std::string & path) {

     // Initialize libcurl
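Note on the new ends_with: std::string::rfind returns npos when the suffix is absent, and str.length() - suffix.length() wraps around (size_t) when the suffix is longer than the string, so the two can coincide; for example ends_with("gguf", ".gguf") evaluates to true. A bounds-checked variant, as a sketch rather than part of this commit:

    #include <string>

    // true only if str is at least as long as suffix and actually ends with it
    static bool ends_with_checked(const std::string & str, const std::string & suffix) {
        return str.size() >= suffix.size()
            && str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
    }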
@@ -2341,26 +2346,91 @@ struct llama_model * llama_load_model_from_url(
     return llama_load_model_from_file(path_model, params);
 }

-struct llama_model * llama_load_model_from_hf(
-        const char * repo,
-        const char * model,
-        const char * path_model,
-        const struct llama_model_params & params) {
 // construct hugging face model url:
 //
 //   --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
 //     https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
 //
 //   --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
 //     https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
 //
+static std::string llama_get_hf_model_url(
+        std::string & repo,
+        std::string & custom_file_path) {
+    std::stringstream ss;
+    json repo_files;
+
-    std::string model_url = "https://huggingface.co/";
-    model_url += repo;
-    model_url += "/resolve/main/";
-    model_url += model;
+    if (!custom_file_path.empty()) {
+        ss << "https://huggingface.co/" << repo << "/resolve/main/" << custom_file_path;
+        return ss.str();
+    }

-    return llama_load_model_from_url(model_url.c_str(), path_model, params);
+    {
+        // Initialize libcurl
+        std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
+
+        // Make the request to Hub API
+        ss << "https://huggingface.co/api/models/" << repo << "/tree/main?recursive=true";
+        std::string url = ss.str();
+        std::string res_str;
+        curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
+        curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
+        typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
+        auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
+            static_cast<std::string *>(data)->append((char *) ptr, size * nmemb);
+            return size * nmemb;
+        };
+        curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
+        curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
+#if defined(_WIN32)
+        curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
+#endif
+        CURLcode res = curl_easy_perform(curl.get());
+
+        if (res != CURLE_OK) {
+            fprintf(stderr, "%s: cannot make GET request to Hugging Face Hub API\n", __func__);
+            return "";
+        }
+
+        long res_code;
+        curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
+        if (res_code != 200) {
+            fprintf(stderr, "%s: Hugging Face Hub API responded with status code %ld\n", __func__, res_code);
+            return "";
+        } else {
+            repo_files = json::parse(res_str);
+        }
+    }
+
+    if (!repo_files.is_array()) {
+        fprintf(stderr, "%s: response from Hugging Face Hub API is not an array\nRaw response:\n%s", __func__, repo_files.dump(4).c_str());
+        return "";
+    }
+
+    auto get_file_contains = [&](std::string piece) -> std::string {
+        for (auto elem : repo_files) {
+            std::string type = elem.at("type");
+            std::string path = elem.at("path");
+            if (
+                type == "file"
+                && ends_with(path, ".gguf")
+                && tolower(path).find(piece) != std::string::npos
+            ) return path;
+        }
+        return "";
+    };
+
+    std::string file_path = get_file_contains("q4_k_m");
+    if (file_path.empty()) {
+        file_path = get_file_contains("q4");
+    }
+    if (file_path.empty()) {
+        file_path = get_file_contains("00001");
+    }
+    if (file_path.empty()) {
+        file_path = get_file_contains("gguf");
+    }
+
+    if (file_path.empty()) {
+        fprintf(stderr, "%s: cannot find any gguf file in the given repository\n", __func__);
+        return "";
+    }
+
+    ss = std::stringstream();
+    ss << "https://huggingface.co/" << repo << "/resolve/main/" << file_path;
+    return ss.str();
+}

 #else
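For context: the tree endpoint queried above (https://huggingface.co/api/models/<repo>/tree/main?recursive=true) returns a JSON array of objects carrying at least "type" and "path" fields, and the cascade prefers a q4_k_m quant, then any q4, then the first shard ("00001"), then any .gguf at all. A minimal offline sketch of that selection against a hand-written response (sample data is illustrative, not a captured API reply; json is nlohmann::json, as in common.cpp):

    #include <cstdio>
    #include <string>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        // hand-written stand-in for a tree?recursive=true response
        const json repo_files = json::parse(R"([
            {"type": "directory", "path": "images"},
            {"type": "file", "path": "README.md"},
            {"type": "file", "path": "mixtral-8x7b-v0.1.Q2_K.gguf"},
            {"type": "file", "path": "mixtral-8x7b-v0.1.Q4_K_M.gguf"}
        ])");

        auto get_file_contains = [&](const std::string & piece) -> std::string {
            for (const auto & elem : repo_files) {
                const std::string type = elem.at("type");
                const std::string path = elem.at("path");
                // same checks as the commit, inlined (the commit lower-cases path first)
                if (type == "file" &&
                    path.size() >= 5 && path.compare(path.size() - 5, 5, ".gguf") == 0 &&
                    path.find(piece) != std::string::npos) {
                    return path;
                }
            }
            return "";
        };

        std::string file_path = get_file_contains("Q4_K_M");
        if (file_path.empty()) file_path = get_file_contains("Q4");
        if (file_path.empty()) file_path = get_file_contains("00001");
        if (file_path.empty()) file_path = get_file_contains("gguf");

        printf("picked: %s\n", file_path.c_str()); // mixtral-8x7b-v0.1.Q4_K_M.gguf
        return 0;
    }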
@@ -2373,11 +2443,9 @@ struct llama_model * llama_load_model_from_url(
     return nullptr;
 }

-struct llama_model * llama_load_model_from_hf(
-        const char * /*repo*/,
-        const char * /*model*/,
-        const char * /*path_model*/,
-        const struct llama_model_params & /*params*/) {
+static std::string llama_get_hf_model_url(
+        std::string & /*repo*/,
+        std::string & /*custom_file_path*/) {
     fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
-    return nullptr;
+    return "";
 }
common/common.h

@@ -223,7 +223,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params);
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);

 struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
-struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);
+static std::string llama_get_hf_model_url(std::string & repo, std::string & custom_file_path);

 // Batch utils
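One possible follow-up here: llama_get_hf_model_url is declared static in a header, so every translation unit that includes common.h gets its own internal-linkage declaration (and "declared static but never defined" warnings in TUs that never define it), whereas the neighbouring helpers use external linkage. A sketch of the conventional split, as a suggestion rather than part of this commit:

    // common.h: plain declaration, external linkage like llama_load_model_from_url above
    std::string llama_get_hf_model_url(std::string & repo, std::string & custom_file_path);

with the definition in common.cpp dropping the static keyword to match.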