diff --git a/common/arg.cpp b/common/arg.cpp index 0c652da52..ede118d19 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -298,6 +298,27 @@ static void common_params_print_usage(common_params_context & ctx_arg) { print_options(specific_options); } +static std::vector parse_device_list(const std::string & value) { + std::vector devices; + auto dev_names = string_split(value, ','); + if (dev_names.empty()) { + throw std::invalid_argument("no devices specified"); + } + if (dev_names.size() == 1 && dev_names[0] == "none") { + devices.push_back(nullptr); + } else { + for (const auto & device : dev_names) { + auto * dev = ggml_backend_dev_by_name(device.c_str()); + if (!dev || ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) { + throw std::invalid_argument(string_format("invalid device: %s", device.c_str())); + } + devices.push_back(dev); + } + devices.push_back(nullptr); + } + return devices; +} + bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) { auto ctx_arg = common_params_parser_init(params, ex, print_usage); const common_params params_org = ctx_arg.params; // the example can modify the default params @@ -1314,21 +1335,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_env("LLAMA_ARG_NUMA")); add_opt(common_arg( {"-dev", "--device"}, "", - "comma-separated list of devices to use for offloading\n" + "comma-separated list of devices to use for offloading (none = don't offload)\n" "use --list-devices to see a list of available devices", [](common_params & params, const std::string & value) { - auto devices = string_split(value, ','); - if (devices.empty()) { - throw std::invalid_argument("no devices specified"); - } - for (const auto & device : devices) { - auto * dev = ggml_backend_dev_by_name(device.c_str()); - if (!dev || ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) { - throw std::invalid_argument(string_format("invalid device: %s", device.c_str())); - } - params.devices.push_back(dev); - } - params.devices.push_back(nullptr); + params.devices = parse_device_list(value); } ).set_env("LLAMA_ARG_DEVICES")); add_opt(common_arg( @@ -2074,21 +2084,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-devd", "--device-draft"}, "", - "comma-separated list of devices to use for offloading the draft model\n" + "comma-separated list of devices to use for offloading the draft model (none = don't offload)\n" "use --list-devices to see a list of available devices", [](common_params & params, const std::string & value) { - auto devices = string_split(value, ','); - if (devices.empty()) { - throw std::invalid_argument("no devices specified"); - } - for (const auto & device : devices) { - auto * dev = ggml_backend_dev_by_name(device.c_str()); - if (!dev || ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) { - throw std::invalid_argument(string_format("invalid device: %s", device.c_str())); - } - params.speculative.devices.push_back(dev); - } - params.speculative.devices.push_back(nullptr); + params.speculative.devices = parse_device_list(value); } ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg(