diff --git a/common/common.cpp b/common/common.cpp index 2e7ddce12..8cede30b0 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1310,6 +1310,29 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return false; } +void gpt_params_handle_model_default(gpt_params & params) { + if (!params.hf_repo.empty()) { + // short-hand to avoid specifying --hf-file -> default it to --model + if (params.hf_file.empty()) { + if (params.model.empty()) { + throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n"); + } + params.hf_file = params.model; + } else if (params.model.empty()) { + params.model = "models/" + string_split(params.hf_file, '/').back(); + } + } else if (!params.model_url.empty()) { + if (params.model.empty()) { + auto f = string_split(params.model_url, '#').front(); + f = string_split(f, '?').front(); + f = string_split(f, '/').back(); + params.model = "models/" + f; + } + } else if (params.model.empty()) { + params.model = "models/7B/ggml-model-f16.gguf"; + } +} + bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { bool invalid_param = false; std::string arg; @@ -1338,26 +1361,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); } - if (!params.hf_repo.empty()) { - // short-hand to avoid specifying --hf-file -> default it to --model - if (params.hf_file.empty()) { - if (params.model.empty()) { - throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n"); - } - params.hf_file = params.model; - } else if (params.model.empty()) { - params.model = "models/" + string_split(params.hf_file, '/').back(); - } - } else if (!params.model_url.empty()) { - if (params.model.empty()) { - auto f = string_split(params.model_url, '#').front(); - f = string_split(f, '?').front(); - f = string_split(f, '/').back(); - params.model = "models/" + f; - } - } else if (params.model.empty()) { - params.model = "models/7B/ggml-model-f16.gguf"; - } + gpt_params_handle_model_default(params); if (params.escape) { process_escapes(params.prompt); diff --git a/common/common.h b/common/common.h index ff0eed055..d828c3e7d 100644 --- a/common/common.h +++ b/common/common.h @@ -170,6 +170,8 @@ struct gpt_params { std::string image = ""; // path to an image file }; +void gpt_params_handle_model_default(gpt_params & params); + bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params); bool gpt_params_parse(int argc, char ** argv, gpt_params & params); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 3acbd17df..cf57e96e9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2846,6 +2846,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, } } + gpt_params_handle_model_default(params); + if (!params.kv_overrides.empty()) { params.kv_overrides.emplace_back(); params.kv_overrides.back().key[0] = 0;