diff --git a/common/arg.cpp b/common/arg.cpp index 57b0d562a..126970950 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -133,7 +133,8 @@ static void common_params_handle_model_default( const std::string & model_url, std::string & hf_repo, std::string & hf_file, - const std::string & hf_token) { + const std::string & hf_token, + const std::string & model_default) { if (!hf_repo.empty()) { // short-hand to avoid specifying --hf-file -> default it to --model if (hf_file.empty()) { @@ -163,7 +164,7 @@ static void common_params_handle_model_default( model = fs_get_cache_file(string_split(f, '/').back()); } } else if (model.empty()) { - model = DEFAULT_MODEL_PATH; + model = model_default; } } @@ -299,9 +300,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context } // TODO: refactor model params in a common struct - common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token); - common_params_handle_model_default(params.speculative.model, params.speculative.model_url, params.speculative.hf_repo, params.speculative.hf_file, params.hf_token); - common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token); + common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token, DEFAULT_MODEL_PATH); + common_params_handle_model_default(params.speculative.model, params.speculative.model_url, params.speculative.hf_repo, params.speculative.hf_file, params.hf_token, ""); + common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token, ""); if (params.escape) { string_process_escapes(params.prompt); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index d1e8ee829..f35206d7b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1728,13 +1728,16 @@ struct server_context { add_bos_token = llama_vocab_get_add_bos(vocab); has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; - if (!params_base.speculative.model.empty()) { + if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) { SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str()); auto params_dft = params_base; params_dft.devices = params_base.speculative.devices; + params_dft.hf_file = params_base.speculative.hf_file; + params_dft.hf_repo = params_base.speculative.hf_repo; params_dft.model = params_base.speculative.model; + params_dft.model_url = params_base.speculative.model_url; params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx; params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; params_dft.n_parallel = 1;