common : support HF download for vocoder

This commit is contained in:
Georgi Gerganov 2024-12-18 19:22:56 +02:00
parent a95191c468
commit c0df192838
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 50 additions and 20 deletions

View file

@ -119,29 +119,33 @@ std::string common_arg::to_string() {
// utils // utils
// //
static void common_params_handle_model_default(common_params & params) { static void common_params_handle_model_default(
if (!params.hf_repo.empty()) { std::string & model,
std::string & model_url,
std::string & hf_repo,
std::string & hf_file) {
if (!hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model // short-hand to avoid specifying --hf-file -> default it to --model
if (params.hf_file.empty()) { if (hf_file.empty()) {
if (params.model.empty()) { if (model.empty()) {
throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n"); throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n");
} }
params.hf_file = params.model; hf_file = model;
} else if (params.model.empty()) { } else if (model.empty()) {
// this is to avoid different repo having same file name, or same file name in different subdirs // this is to avoid different repo having same file name, or same file name in different subdirs
std::string filename = params.hf_repo + "_" + params.hf_file; std::string filename = hf_repo + "_" + hf_file;
// to make sure we don't have any slashes in the filename // to make sure we don't have any slashes in the filename
string_replace_all(filename, "/", "_"); string_replace_all(filename, "/", "_");
params.model = fs_get_cache_file(filename); model = fs_get_cache_file(filename);
} }
} else if (!params.model_url.empty()) { } else if (!model_url.empty()) {
if (params.model.empty()) { if (model.empty()) {
auto f = string_split<std::string>(params.model_url, '#').front(); auto f = string_split<std::string>(model_url, '#').front();
f = string_split<std::string>(f, '?').front(); f = string_split<std::string>(f, '?').front();
params.model = fs_get_cache_file(string_split<std::string>(f, '/').back()); model = fs_get_cache_file(string_split<std::string>(f, '/').back());
} }
} else if (params.model.empty()) { } else if (model.empty()) {
params.model = DEFAULT_MODEL_PATH; model = DEFAULT_MODEL_PATH;
} }
} }
@ -276,7 +280,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
} }
common_params_handle_model_default(params); // TODO: refactor model params in a common struct
common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file);
common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file);
if (params.escape) { if (params.escape) {
string_process_escapes(params.prompt); string_process_escapes(params.prompt);
@ -1581,6 +1587,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.hf_file = value; params.hf_file = value;
} }
).set_env("LLAMA_ARG_HF_FILE")); ).set_env("LLAMA_ARG_HF_FILE"));
add_opt(common_arg(
{"-hfrv", "--hf-repo-v"}, "REPO",
"Hugging Face model repository for the vocoder model (default: unused)",
[](common_params & params, const std::string & value) {
params.vocoder.hf_repo = value;
}
).set_env("LLAMA_ARG_HF_REPO_V"));
add_opt(common_arg(
{"-hffv", "--hf-file-v"}, "FILE",
"Hugging Face model file for the vocoder model (default: unused)",
[](common_params & params, const std::string & value) {
params.vocoder.hf_file = value;
}
).set_env("LLAMA_ARG_HF_FILE_V"));
add_opt(common_arg( add_opt(common_arg(
{"-hft", "--hf-token"}, "TOKEN", {"-hft", "--hf-token"}, "TOKEN",
"Hugging Face access token (default: value from HF_TOKEN environment variable)", "Hugging Face access token (default: value from HF_TOKEN environment variable)",

View file

@ -1095,7 +1095,7 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
#define CURL_MAX_RETRY 3 #define CURL_MAX_RETRY 3
#define CURL_RETRY_DELAY_SECONDS 2 #define CURL_RETRY_DELAY_SECONDS 2
static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_attempts, int retry_delay_seconds) { static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
int remaining_attempts = max_attempts; int remaining_attempts = max_attempts;
while (remaining_attempts > 0) { while (remaining_attempts > 0) {
@ -1119,7 +1119,6 @@ static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_
} }
static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
// Initialize libcurl // Initialize libcurl
std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup); std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
if (!curl) { if (!curl) {
@ -1192,11 +1191,13 @@ static bool common_download_file(const std::string & url, const std::string & pa
std::string etag; std::string etag;
std::string last_modified; std::string last_modified;
}; };
common_load_model_from_url_headers headers; common_load_model_from_url_headers headers;
{ {
typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
common_load_model_from_url_headers *headers = (common_load_model_from_url_headers *) userdata; common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
static std::regex header_regex("([^:]+): (.*)\r\n"); static std::regex header_regex("([^:]+): (.*)\r\n");
static std::regex etag_regex("ETag", std::regex_constants::icase); static std::regex etag_regex("ETag", std::regex_constants::icase);

View file

@ -175,7 +175,11 @@ struct common_params_speculative {
}; };
struct common_params_vocoder { struct common_params_vocoder {
std::string model = ""; // vocoder model for producing audio // NOLINT std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string model = ""; // model path // NOLINT
std::string model_url = ""; // model url to download // NOLINT
}; };
struct common_params { struct common_params {

View file

@ -461,7 +461,12 @@ int main(int argc, char ** argv) {
model_ttc = llama_init_ttc.model; model_ttc = llama_init_ttc.model;
ctx_ttc = llama_init_ttc.context; ctx_ttc = llama_init_ttc.context;
// TODO: refactor in a common struct
params.model = params.vocoder.model; params.model = params.vocoder.model;
params.model_url = params.vocoder.model_url;
params.hf_repo = params.vocoder.hf_repo;
params.hf_file = params.vocoder.hf_file;
params.embedding = true; params.embedding = true;
common_init_result llama_init_cts = common_init_from_params(params); common_init_result llama_init_cts = common_init_from_params(params);