From 6ef22f05472d061d467301604ad7078f1e7e428a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 20 Jan 2025 21:46:58 +0200 Subject: [PATCH] common : add -hfd option for the draft model --- common/arg.cpp | 12 ++++++++++-- common/common.h | 8 +++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index dede335fb..a72c90bfa 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -299,8 +299,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context } // TODO: refactor model params in a common struct - common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token); - common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token); + common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token); + common_params_handle_model_default(params.speculative.model, params.speculative.model_url, params.speculative.hf_repo, params.speculative.hf_file, params.hf_token); + common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token); if (params.escape) { string_process_escapes(params.prompt); @@ -1629,6 +1630,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.hf_repo = value; } ).set_env("LLAMA_ARG_HF_REPO")); + add_opt(common_arg( + {"-hfd", "-hfrd", "--hf-repo-draft"}, "/[:quant]", + "Same as --hf-repo, but for the draft model (default: unused)", + [](common_params & params, const std::string & value) { + params.speculative.hf_repo = value; + } + ).set_env("LLAMA_ARG_HF_REPO")); add_opt(common_arg( {"-hff", "--hf-file"}, "FILE", "Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)", diff --git a/common/common.h b/common/common.h index 3bcc637cc..b2709c044 100644 --- a/common/common.h +++ b/common/common.h @@ -175,7 +175,11 @@ struct common_params_speculative { struct cpu_params cpuparams; struct cpu_params cpuparams_batch; - std::string model = ""; // draft model for speculative decoding // NOLINT + std::string hf_repo = ""; // HF repo // NOLINT + std::string hf_file = ""; // HF file // NOLINT + + std::string model = ""; // draft model for speculative decoding // NOLINT + std::string model_url = ""; // model url to download // NOLINT }; struct common_params_vocoder { @@ -508,12 +512,14 @@ struct llama_model * common_load_model_from_url( const std::string & local_path, const std::string & hf_token, const struct llama_model_params & params); + struct llama_model * common_load_model_from_hf( const std::string & repo, const std::string & remote_path, const std::string & local_path, const std::string & hf_token, const struct llama_model_params & params); + std::pair common_get_hf_file( const std::string & hf_repo_with_tag, const std::string & hf_token);