From 0de79f7316569f12d95614de0ed2e13f7ccabd69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?DAN=E2=84=A2?= Date: Sat, 16 Mar 2024 10:00:46 -0400 Subject: [PATCH] Refactor nested if causing error C1061 on MSVC. --- common/common.cpp | 1734 +++++++++++++++++++++++++-------------------- 1 file changed, 962 insertions(+), 772 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 4912237e0..09d6c9cf6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -139,6 +139,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { return result; } +constexpr std::uint64_t constexpr_hash(const char* input) { + std::uint64_t hash = 0xcbf29ce484222325; + const std::uint64_t prime = 0x100000001b3; + + for (; *input; ++input) { + hash ^= static_cast<unsigned char>(*input); + hash *= prime; + } + + return hash; +} + bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { bool invalid_param = false; std::string arg; @@ -151,786 +163,964 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { std::replace(arg.begin(), arg.end(), '_', '-'); } - if (arg == "-s" || arg == "--seed") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.seed = std::stoul(argv[i]); - } else if (arg == "-t" || arg == "--threads") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_threads = std::stoi(argv[i]); - if (params.n_threads <= 0) { - params.n_threads = std::thread::hardware_concurrency(); - } - } else if (arg == "-tb" || arg == "--threads-batch") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_threads_batch = std::stoi(argv[i]); - if (params.n_threads_batch <= 0) { - params.n_threads_batch = std::thread::hardware_concurrency(); - } - } else if (arg == "-td" || arg == "--threads-draft") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_threads_draft = std::stoi(argv[i]); - if (params.n_threads_draft <= 0) { - params.n_threads_draft = std::thread::hardware_concurrency(); - }
- } else if (arg == "-tbd" || arg == "--threads-batch-draft") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_threads_batch_draft = std::stoi(argv[i]); - if (params.n_threads_batch_draft <= 0) { - params.n_threads_batch_draft = std::thread::hardware_concurrency(); - } - } else if (arg == "-p" || arg == "--prompt") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.prompt = argv[i]; - } else if (arg == "-e" || arg == "--escape") { - params.escape = true; - } else if (arg == "--prompt-cache") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.path_prompt_cache = argv[i]; - } else if (arg == "--prompt-cache-all") { - params.prompt_cache_all = true; - } else if (arg == "--prompt-cache-ro") { - params.prompt_cache_ro = true; - } else if (arg == "-bf" || arg == "--binary-file") { - if (++i >= argc) { - invalid_param = true; - break; - } - std::ifstream file(argv[i], std::ios::binary); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - break; - } - // store the external file name in params - params.prompt_file = argv[i]; - std::ostringstream ss; - ss << file.rdbuf(); - params.prompt = ss.str(); - fprintf(stderr, "Read %zu bytes from binary file %s\n", params.prompt.size(), argv[i]); - } else if (arg == "-f" || arg == "--file") { - if (++i >= argc) { - invalid_param = true; - break; - } - std::ifstream file(argv[i]); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - break; - } - // store the external file name in params - params.prompt_file = argv[i]; - std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt)); - if (!params.prompt.empty() && params.prompt.back() == '\n') { - params.prompt.pop_back(); - } - } else if (arg == "-n" || arg == "--n-predict") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_predict = std::stoi(argv[i]); - } else if
(arg == "--top-k") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.top_k = std::stoi(argv[i]); - } else if (arg == "-c" || arg == "--ctx-size") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_ctx = std::stoi(argv[i]); - } else if (arg == "--grp-attn-n" || arg == "-gan") { - if (++i >= argc) { - invalid_param = true; - break; - } - - params.grp_attn_n = std::stoi(argv[i]); - } else if (arg == "--grp-attn-w" || arg == "-gaw") { - if (++i >= argc) { - invalid_param = true; - break; - } - - params.grp_attn_w = std::stoi(argv[i]); - } else if (arg == "--rope-freq-base") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.rope_freq_base = std::stof(argv[i]); - } else if (arg == "--rope-freq-scale") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.rope_freq_scale = std::stof(argv[i]); - } else if (arg == "--rope-scaling") { - if (++i >= argc) { - invalid_param = true; - break; - } - std::string value(argv[i]); - /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; } - else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } - else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; } - else { invalid_param = true; break; } - } else if (arg == "--rope-scale") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.rope_freq_scale = 1.0f/std::stof(argv[i]); - } else if (arg == "--yarn-orig-ctx") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.yarn_orig_ctx = std::stoi(argv[i]); - } else if (arg == "--yarn-ext-factor") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.yarn_ext_factor = std::stof(argv[i]); - } else if (arg == "--yarn-attn-factor") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.yarn_attn_factor = std::stof(argv[i]); - } else if (arg == "--yarn-beta-fast") { - if (++i >= argc) { - invalid_param = true; - break; - } 
- params.yarn_beta_fast = std::stof(argv[i]); - } else if (arg == "--yarn-beta-slow") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.yarn_beta_slow = std::stof(argv[i]); - } else if (arg == "--pooling") { - if (++i >= argc) { - invalid_param = true; - break; - } - std::string value(argv[i]); - /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } - else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } - else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } - else { invalid_param = true; break; } - } else if (arg == "--defrag-thold" || arg == "-dt") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.defrag_thold = std::stof(argv[i]); - } else if (arg == "--samplers") { - if (++i >= argc) { - invalid_param = true; - break; - } - const auto sampler_names = string_split(argv[i], ';'); - sparams.samplers_sequence = sampler_types_from_names(sampler_names, true); - } else if (arg == "--sampling-seq") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.samplers_sequence = sampler_types_from_chars(argv[i]); - } else if (arg == "--top-p") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.top_p = std::stof(argv[i]); - } else if (arg == "--min-p") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.min_p = std::stof(argv[i]); - } else if (arg == "--temp") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.temp = std::stof(argv[i]); - sparams.temp = std::max(sparams.temp, 0.0f); - } else if (arg == "--tfs") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.tfs_z = std::stof(argv[i]); - } else if (arg == "--typical") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.typical_p = std::stof(argv[i]); - } else if (arg == "--repeat-last-n") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.penalty_last_n = std::stoi(argv[i]); - sparams.n_prev = 
std::max(sparams.n_prev, sparams.penalty_last_n); - } else if (arg == "--repeat-penalty") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.penalty_repeat = std::stof(argv[i]); - } else if (arg == "--frequency-penalty") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.penalty_freq = std::stof(argv[i]); - } else if (arg == "--presence-penalty") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.penalty_present = std::stof(argv[i]); - } else if (arg == "--dynatemp-range") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.dynatemp_range = std::stof(argv[i]); - } else if (arg == "--dynatemp-exp") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.dynatemp_exponent = std::stof(argv[i]); - } else if (arg == "--mirostat") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.mirostat = std::stoi(argv[i]); - } else if (arg == "--mirostat-lr") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.mirostat_eta = std::stof(argv[i]); - } else if (arg == "--mirostat-ent") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.mirostat_tau = std::stof(argv[i]); - } else if (arg == "--cfg-negative-prompt") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.cfg_negative_prompt = argv[i]; - } else if (arg == "--cfg-negative-prompt-file") { - if (++i >= argc) { - invalid_param = true; - break; - } - std::ifstream file(argv[i]); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - break; - } - std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(sparams.cfg_negative_prompt)); - if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') { - sparams.cfg_negative_prompt.pop_back(); - } - } else if (arg == "--cfg-scale") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.cfg_scale =
std::stof(argv[i]); - } else if (arg == "-b" || arg == "--batch-size") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_batch = std::stoi(argv[i]); - } else if (arg == "-ub" || arg == "--ubatch-size") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_ubatch = std::stoi(argv[i]); - } else if (arg == "--keep") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_keep = std::stoi(argv[i]); - } else if (arg == "--draft") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_draft = std::stoi(argv[i]); - } else if (arg == "--chunks") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_chunks = std::stoi(argv[i]); - } else if (arg == "-np" || arg == "--parallel") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_parallel = std::stoi(argv[i]); - } else if (arg == "-ns" || arg == "--sequences") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_sequences = std::stoi(argv[i]); - } else if (arg == "--p-split" || arg == "-ps") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.p_split = std::stof(argv[i]); - } else if (arg == "-m" || arg == "--model") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.model = argv[i]; - } else if (arg == "-md" || arg == "--model-draft") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.model_draft = argv[i]; - } else if (arg == "-a" || arg == "--alias") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.model_alias = argv[i]; - } else if (arg == "--lora") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.lora_adapter.emplace_back(argv[i], 1.0f); - params.use_mmap = false; - } else if (arg == "--lora-scaled") { - if (++i >= argc) { - invalid_param = true; - break; - } - const char * lora_adapter = argv[i]; - if (++i >= argc) { - invalid_param = true; - break; - } - params.lora_adapter.emplace_back(lora_adapter, 
std::stof(argv[i])); - params.use_mmap = false; - } else if (arg == "--lora-base") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.lora_base = argv[i]; - } else if (arg == "--control-vector") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.control_vectors.push_back({ 1.0f, argv[i], }); - } else if (arg == "--control-vector-scaled") { - if (++i >= argc) { - invalid_param = true; - break; - } - const char * fname = argv[i]; - if (++i >= argc) { - invalid_param = true; - break; - } - params.control_vectors.push_back({ std::stof(argv[i]), fname, }); - } else if (arg == "--control-vector-layer-range") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.control_vector_layer_start = std::stoi(argv[i]); - if (++i >= argc) { - invalid_param = true; - break; - } - params.control_vector_layer_end = std::stoi(argv[i]); - } else if (arg == "--mmproj") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.mmproj = argv[i]; - } else if (arg == "--image") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.image = argv[i]; - } else if (arg == "-i" || arg == "--interactive") { - params.interactive = true; - } else if (arg == "--embedding") { - params.embedding = true; - } else if (arg == "--interactive-first") { - params.interactive_first = true; - } else if (arg == "-ins" || arg == "--instruct") { - params.instruct = true; - } else if (arg == "-cml" || arg == "--chatml") { - params.chatml = true; - } else if (arg == "--infill") { - params.infill = true; - } else if (arg == "-dkvc" || arg == "--dump-kv-cache") { - params.dump_kv_cache = true; - } else if (arg == "-nkvo" || arg == "--no-kv-offload") { - params.no_kv_offload = true; - } else if (arg == "-ctk" || arg == "--cache-type-k") { - params.cache_type_k = argv[++i]; - } else if (arg == "-ctv" || arg == "--cache-type-v") { - params.cache_type_v = argv[++i]; - } else if (arg == "--multiline-input") { - params.multiline_input = true; - } 
else if (arg == "--simple-io") { - params.simple_io = true; - } else if (arg == "-cb" || arg == "--cont-batching") { - params.cont_batching = true; - } else if (arg == "--color") { - params.use_color = true; - } else if (arg == "--mlock") { - params.use_mlock = true; - } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_gpu_layers = std::stoi(argv[i]); - if (!llama_supports_gpu_offload()) { - fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); - fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); - } - } else if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_gpu_layers_draft = std::stoi(argv[i]); - if (!llama_supports_gpu_offload()) { - fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n"); - fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); - } - } else if (arg == "--main-gpu" || arg == "-mg") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.main_gpu = std::stoi(argv[i]); -#ifndef GGML_USE_CUBLAS_SYCL - fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL. Setting the main GPU has no effect.\n"); -#endif // GGML_USE_CUBLAS_SYCL - } else if (arg == "--split-mode" || arg == "-sm") { - if (++i >= argc) { - invalid_param = true; - break; - } - std::string arg_next = argv[i]; - if (arg_next == "none") { - params.split_mode = LLAMA_SPLIT_MODE_NONE; - } else if (arg_next == "layer") { - params.split_mode = LLAMA_SPLIT_MODE_LAYER; - } else if (arg_next == "row") { -#ifdef GGML_USE_SYCL - fprintf(stderr, "warning: The split mode value:[row] is not supported by llama.cpp with SYCL. 
It's developing.\nExit!\n"); - exit(1); -#endif // GGML_USE_SYCL - params.split_mode = LLAMA_SPLIT_MODE_ROW; - } else { - invalid_param = true; - break; - } -#ifndef GGML_USE_CUBLAS_SYCL - fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL. Setting the split mode has no effect.\n"); -#endif // GGML_USE_CUBLAS_SYCL - - } else if (arg == "--tensor-split" || arg == "-ts") { - if (++i >= argc) { - invalid_param = true; - break; - } - std::string arg_next = argv[i]; - - // split string by , and / - const std::regex regex{R"([,/]+)"}; - std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; - std::vector<std::string> split_arg{it, {}}; - if (split_arg.size() >= llama_max_devices()) { - invalid_param = true; - break; - } - for (size_t i = 0; i < llama_max_devices(); ++i) { - if (i < split_arg.size()) { - params.tensor_split[i] = std::stof(split_arg[i]); - } else { - params.tensor_split[i] = 0.0f; - } - } -#ifndef GGML_USE_CUBLAS_SYCL_VULKAN - fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL/Vulkan.
Setting a tensor split has no effect.\n"); -#endif // GGML_USE_CUBLAS_SYCL - } else if (arg == "--no-mmap") { - params.use_mmap = false; - } else if (arg == "--numa") { - if (++i >= argc) { - invalid_param = true; - break; - } - std::string value(argv[i]); - /**/ if (value == "distribute" || value == "") { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; } - else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; } - else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; } - else { invalid_param = true; break; } - } else if (arg == "--verbose-prompt") { - params.verbose_prompt = true; - } else if (arg == "--no-display-prompt") { - params.display_prompt = false; - } else if (arg == "-r" || arg == "--reverse-prompt") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.antiprompt.emplace_back(argv[i]); - } else if (arg == "-ld" || arg == "--logdir") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.logdir = argv[i]; - - if (params.logdir.back() != DIRECTORY_SEPARATOR) { - params.logdir += DIRECTORY_SEPARATOR; - } - } else if (arg == "--save-all-logits" || arg == "--kl-divergence-base") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.logits_file = argv[i]; - } else if (arg == "--perplexity" || arg == "--all-logits") { - params.logits_all = true; - } else if (arg == "--ppl-stride") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.ppl_stride = std::stoi(argv[i]); - } else if (arg == "-ptc" || arg == "--print-token-count") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_print = std::stoi(argv[i]); - } else if (arg == "--ppl-output-type") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.ppl_output_type = std::stoi(argv[i]); - } else if (arg == "--hellaswag") { - params.hellaswag = true; - } else if (arg == "--hellaswag-tasks") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.hellaswag_tasks = 
std::stoi(argv[i]); - } else if (arg == "--winogrande") { - params.winogrande = true; - } else if (arg == "--winogrande-tasks") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.winogrande_tasks = std::stoi(argv[i]); - } else if (arg == "--multiple-choice") { - params.multiple_choice = true; - } else if (arg == "--multiple-choice-tasks") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.multiple_choice_tasks = std::stoi(argv[i]); - } else if (arg == "--kl-divergence") { - params.kl_divergence = true; - } else if (arg == "--ignore-eos") { - params.ignore_eos = true; - } else if (arg == "--no-penalize-nl") { - sparams.penalize_nl = false; - } else if (arg == "-l" || arg == "--logit-bias") { - if (++i >= argc) { - invalid_param = true; - break; - } - std::stringstream ss(argv[i]); - llama_token key; - char sign; - std::string value_str; - try { - if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { - sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? 
-1.0f : 1.0f); - } else { - throw std::exception(); - } - } catch (const std::exception&) { - invalid_param = true; - break; - } - } else if (arg == "-h" || arg == "--help") { - return false; - - } else if (arg == "--version") { - fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT); - fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET); - exit(0); - } else if (arg == "--random-prompt") { - params.random_prompt = true; - } else if (arg == "--in-prefix-bos") { - params.input_prefix_bos = true; - } else if (arg == "--in-prefix") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.input_prefix = argv[i]; - } else if (arg == "--in-suffix") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.input_suffix = argv[i]; - } else if (arg == "--grammar") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.grammar = argv[i]; - } else if (arg == "--grammar-file") { - if (++i >= argc) { - invalid_param = true; - break; - } - std::ifstream file(argv[i]); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - break; - } - std::copy( - std::istreambuf_iterator<char>(file), - std::istreambuf_iterator<char>(), - std::back_inserter(sparams.grammar) - ); - } else if (arg == "--override-kv") { - if (++i >= argc) { - invalid_param = true; - break; - } - char * sep = strchr(argv[i], '='); - if (sep == nullptr || sep - argv[i] >= 128) { - fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]); - invalid_param = true; - break; - } - struct llama_model_kv_override kvo; - std::strncpy(kvo.key, argv[i], sep - argv[i]); - kvo.key[sep - argv[i]] = 0; - sep++; - if (strncmp(sep, "int:", 4) == 0) { - sep += 4; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; - kvo.int_value = std::atol(sep); - } else if (strncmp(sep, "float:", 6) == 0) { - sep += 6; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; - kvo.float_value = std::atof(sep); - } else if (strncmp(sep, "bool:", 5) ==
0) { - sep += 5; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; - if (std::strcmp(sep, "true") == 0) { - kvo.bool_value = true; - } else if (std::strcmp(sep, "false") == 0) { - kvo.bool_value = false; - } else { - fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]); + switch (constexpr_hash(argv[i])) + { + case constexpr_hash("-s"): + case constexpr_hash("--seed"): + if (++i >= argc) { invalid_param = true; break; } - } else { - fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); - invalid_param = true; + params.seed = std::stoul(argv[i]); + break; + case constexpr_hash("-t"): + case constexpr_hash("--threads"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_threads = std::stoi(argv[i]); + if (params.n_threads <= 0) { + params.n_threads = std::thread::hardware_concurrency(); + } + break; + case constexpr_hash("-tb"): + case constexpr_hash("--threads-batch"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_threads_batch = std::stoi(argv[i]); + if (params.n_threads_batch <= 0) { + params.n_threads_batch = std::thread::hardware_concurrency(); + } + break; + case constexpr_hash("-td"): + case constexpr_hash("--threads-draft"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_threads_draft = std::stoi(argv[i]); + if (params.n_threads_draft <= 0) { + params.n_threads_draft = std::thread::hardware_concurrency(); + } + break; + case constexpr_hash("-tbd"): + case constexpr_hash("--threads-batch-draft"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_threads_batch_draft = std::stoi(argv[i]); + if (params.n_threads_batch_draft <= 0) { + params.n_threads_batch_draft = std::thread::hardware_concurrency(); + } + break; + case constexpr_hash("-p"): + case constexpr_hash("--prompt"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.prompt = argv[i]; + break; + case constexpr_hash("-e"): + case constexpr_hash("--escape"): + params.escape = 
true; + break; + case constexpr_hash("--prompt-cache"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.path_prompt_cache = argv[i]; + break; + case constexpr_hash("--prompt-cache-all"): + params.prompt_cache_all = true; + break; + case constexpr_hash("--prompt-cache-ro"): + params.prompt_cache_ro = true; + break; + case constexpr_hash("-bf"): + case constexpr_hash("--binary-file"): { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i], std::ios::binary); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + // store the external file name in params + params.prompt_file = argv[i]; + std::ostringstream ss; + ss << file.rdbuf(); + params.prompt = ss.str(); + fprintf(stderr, "Read %zu bytes from binary file %s\n", params.prompt.size(), argv[i]); break; } - params.kv_overrides.push_back(kvo); + case constexpr_hash("-f"): + case constexpr_hash("--file"): { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + // store the external file name in params + params.prompt_file = argv[i]; + std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt)); + if (!params.prompt.empty() && params.prompt.back() == '\n') { + params.prompt.pop_back(); + } + break; + } + case constexpr_hash("-n"): + case constexpr_hash("--n-predict"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_predict = std::stoi(argv[i]); + break; + case constexpr_hash("--top-k"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.top_k = std::stoi(argv[i]); + break; + case constexpr_hash("-c"): + case constexpr_hash("--ctx-size"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_ctx = std::stoi(argv[i]); + break; + case constexpr_hash("--grp-attn-n"):
+ case constexpr_hash("-gan"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.grp_attn_n = std::stoi(argv[i]); + break; + case constexpr_hash("--grp-attn-w"): + case constexpr_hash("-gaw"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.grp_attn_w = std::stoi(argv[i]); + break; + case constexpr_hash("--rope-freq-base"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_freq_base = std::stof(argv[i]); + break; + case constexpr_hash("--rope-freq-scale"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_freq_scale = std::stof(argv[i]); + break; + case constexpr_hash("--rope-scaling"): { + if (++i >= argc) { + invalid_param = true; + break; + } + std::string value(argv[i]); + /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; } + else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } + else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; } + else { invalid_param = true; break; } + break; + } + case constexpr_hash("--rope-scale"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_freq_scale = 1.0f / std::stof(argv[i]); + break; + case constexpr_hash("--yarn-orig-ctx"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.yarn_orig_ctx = std::stoi(argv[i]); + break; + case constexpr_hash("--yarn-ext-factor"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.yarn_ext_factor = std::stof(argv[i]); + break; + case constexpr_hash("--yarn-attn-factor"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.yarn_attn_factor = std::stof(argv[i]); + break; + case constexpr_hash("--yarn-beta-fast"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.yarn_beta_fast = std::stof(argv[i]); + break; + case constexpr_hash("--yarn-beta-slow"): + if (++i >= argc) { + invalid_param = true; + break; + } + 
params.yarn_beta_slow = std::stof(argv[i]); + break; + case constexpr_hash("--pooling"): { + if (++i >= argc) { + invalid_param = true; + break; + } + std::string value(argv[i]); + /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } + else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } + else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else { invalid_param = true; break; } + break; + } + case constexpr_hash("--defrag-thold"): + case constexpr_hash("-dt"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.defrag_thold = std::stof(argv[i]); + break; + case constexpr_hash("--samplers"): + if (++i >= argc) { + invalid_param = true; + break; + } + { + const auto sampler_names = string_split(argv[i], ';'); + sparams.samplers_sequence = sampler_types_from_names(sampler_names, true); + } + break; + case constexpr_hash("--sampling-seq"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.samplers_sequence = sampler_types_from_chars(argv[i]); + break; + case constexpr_hash("--top-p"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.top_p = std::stof(argv[i]); + break; + case constexpr_hash("--min-p"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.min_p = std::stof(argv[i]); + break; + case constexpr_hash("--temp"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.temp = std::stof(argv[i]); + sparams.temp = std::max(sparams.temp, 0.0f); + break; + case constexpr_hash("--tfs"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.tfs_z = std::stof(argv[i]); + break; + case constexpr_hash("--typical"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.typical_p = std::stof(argv[i]); + break; + case constexpr_hash("--repeat-last-n"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.penalty_last_n = std::stoi(argv[i]); + sparams.n_prev = 
std::max(sparams.n_prev, sparams.penalty_last_n); + break; + case constexpr_hash("--repeat-penalty"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.penalty_repeat = std::stof(argv[i]); + break; + case constexpr_hash("--frequency-penalty"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.penalty_freq = std::stof(argv[i]); + break; + case constexpr_hash("--presence-penalty"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.penalty_present = std::stof(argv[i]); + break; + case constexpr_hash("--dynatemp-range"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.dynatemp_range = std::stof(argv[i]); + break; + case constexpr_hash("--dynatemp-exp"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.dynatemp_exponent = std::stof(argv[i]); + break; + case constexpr_hash("--mirostat"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.mirostat = std::stoi(argv[i]); + break; + case constexpr_hash("--mirostat-lr"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.mirostat_eta = std::stof(argv[i]); + break; + case constexpr_hash("--mirostat-ent"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.mirostat_tau = std::stof(argv[i]); + break; + case constexpr_hash("--cfg-negative-prompt"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.cfg_negative_prompt = argv[i]; + break; + case constexpr_hash("--cfg-negative-prompt-file"): { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(sparams.cfg_negative_prompt)); + if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') { + sparams.cfg_negative_prompt.pop_back(); + } + break; + } + case
constexpr_hash("--cfg-scale"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.cfg_scale = std::stof(argv[i]); + break; + case constexpr_hash("-b"): + case constexpr_hash("--batch-size"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_batch = std::stoi(argv[i]); + break; + case constexpr_hash("-ub"): + case constexpr_hash("--ubatch-size"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_ubatch = std::stoi(argv[i]); + break; + case constexpr_hash("--keep"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_keep = std::stoi(argv[i]); + break; + case constexpr_hash("--draft"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_draft = std::stoi(argv[i]); + break; + case constexpr_hash("--chunks"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_chunks = std::stoi(argv[i]); + break; + case constexpr_hash("-np"): + case constexpr_hash("--parallel"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_parallel = std::stoi(argv[i]); + break; + case constexpr_hash("-ns"): + case constexpr_hash("--sequences"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_sequences = std::stoi(argv[i]); + break; + case constexpr_hash("--p-split"): + case constexpr_hash("-ps"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.p_split = std::stof(argv[i]); + break; + case constexpr_hash("-m"): + case constexpr_hash("--model"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.model = argv[i]; + break; + case constexpr_hash("-md"): + case constexpr_hash("--model-draft"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.model_draft = argv[i]; + break; + case constexpr_hash("-a"): + case constexpr_hash("--alias"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.model_alias = argv[i]; + break; + case constexpr_hash("--lora"): + if (++i >= argc) { + invalid_param = 
true; + break; + } + params.lora_adapter.emplace_back(argv[i], 1.0f); + params.use_mmap = false; + break; + case constexpr_hash("--lora-scaled"): { + if (++i >= argc) { + invalid_param = true; + break; + } + const char * lora_adapter = argv[i]; + if (++i >= argc) { + invalid_param = true; + break; + } + params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); + params.use_mmap = false; + break; + } + case constexpr_hash("--lora-base"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.lora_base = argv[i]; + break; + case constexpr_hash("--control-vector"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.control_vectors.push_back({ 1.0f, argv[i], }); + break; + case constexpr_hash("--control-vector-scaled"): + if (++i >= argc) { + invalid_param = true; + break; + } + { + const char* fname = argv[i]; + if (++i >= argc) { + invalid_param = true; + break; + } + params.control_vectors.push_back({ std::stof(argv[i]), fname, }); + } + break; + case constexpr_hash("--control-vector-layer-range"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.control_vector_layer_start = std::stoi(argv[i]); + if (++i >= argc) { + invalid_param = true; + break; + } + params.control_vector_layer_end = std::stoi(argv[i]); + break; + case constexpr_hash("--mmproj"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.mmproj = argv[i]; + break; + case constexpr_hash("--image"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.image = argv[i]; + break; + case constexpr_hash("-i"): + case constexpr_hash("--interactive"): + params.interactive = true; + break; + case constexpr_hash("--embedding"): + params.embedding = true; + break; + case constexpr_hash("--interactive-first"): + params.interactive_first = true; + break; + case constexpr_hash("-ins"): + case constexpr_hash("--instruct"): + params.instruct = true; + break; + case constexpr_hash("-cml"): + case constexpr_hash("--chatml"): + 
params.chatml = true; + break; + case constexpr_hash("--infill"): + params.infill = true; + break; + case constexpr_hash("-dkvc"): + case constexpr_hash("--dump-kv-cache"): + params.dump_kv_cache = true; + break; + case constexpr_hash("-nkvo"): + case constexpr_hash("--no-kv-offload"): + params.no_kv_offload = true; + break; + case constexpr_hash("-ctk"): + case constexpr_hash("--cache-type-k"): + params.cache_type_k = argv[++i]; + break; + case constexpr_hash("-ctv"): + case constexpr_hash("--cache-type-v"): + params.cache_type_v = argv[++i]; + break; + case constexpr_hash("--multiline-input"): + params.multiline_input = true; + break; + case constexpr_hash("--simple-io"): + params.simple_io = true; + break; + case constexpr_hash("-cb"): + case constexpr_hash("--cont-batching"): + params.cont_batching = true; + break; + case constexpr_hash("--color"): + params.use_color = true; + break; + case constexpr_hash("--mlock"): + params.use_mlock = true; + break; + case constexpr_hash("--gpu-layers"): + case constexpr_hash("-ngl"): + case constexpr_hash("--n-gpu-layers"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_gpu_layers = std::stoi(argv[i]); + if (!llama_supports_gpu_offload()) { + fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); + fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); + } + break; + case constexpr_hash("--gpu-layers-draft"): + case constexpr_hash("-ngld"): + case constexpr_hash("--n-gpu-layers-draft"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_gpu_layers_draft = std::stoi(argv[i]); + if (!llama_supports_gpu_offload()) { + fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n"); + fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); + } + break; + case constexpr_hash("--main-gpu"): + case 
constexpr_hash("-mg"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.main_gpu = std::stoi(argv[i]); +#ifndef GGML_USE_CUBLAS_SYCL + fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL. Setting the main GPU has no effect.\n"); +#endif // GGML_USE_CUBLAS_SYCL + break; + case constexpr_hash("--split-mode"): + case constexpr_hash("-sm"): { + if (++i >= argc) { + invalid_param = true; + break; + } + std::string arg_next = argv[i]; + if (arg_next == "none") { + params.split_mode = LLAMA_SPLIT_MODE_NONE; + } else if (arg_next == "layer") { + params.split_mode = LLAMA_SPLIT_MODE_LAYER; + } else if (arg_next == "row") { +#ifdef GGML_USE_SYCL + fprintf(stderr, "warning: The split mode value:[row] is not supported by llama.cpp with SYCL. It's developing.\nExit!\n"); + exit(1); +#endif // GGML_USE_SYCL + params.split_mode = LLAMA_SPLIT_MODE_ROW; + } else { + invalid_param = true; + break; + } +#ifndef GGML_USE_CUBLAS_SYCL + fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL. Setting the split mode has no effect.\n"); +#endif // GGML_USE_CUBLAS_SYCL + break; + } + case constexpr_hash("--tensor-split"): + case constexpr_hash("-ts"): { + if (++i >= argc) { + invalid_param = true; + break; + } + std::string arg_next = argv[i]; + + // split string by , and / + const std::regex regex{R"([,/]+)"}; + std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; + std::vector split_arg{it, {}}; + if (split_arg.size() >= llama_max_devices()) { + invalid_param = true; + break; + } + for (size_t i = 0; i < llama_max_devices(); ++i) { + if (i < split_arg.size()) { + params.tensor_split[i] = std::stof(split_arg[i]); + } else { + params.tensor_split[i] = 0.0f; + } + } +#ifndef GGML_USE_CUBLAS_SYCL_VULKAN + fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL/Vulkan. 
Setting a tensor split has no effect.\n"); +#endif // GGML_USE_CUBLAS_SYCL + break; + } + case constexpr_hash("--no-mmap"): + params.use_mmap = false; + break; + case constexpr_hash("--numa"): { + if (++i >= argc) { + invalid_param = true; + break; + } + std::string value(argv[i]); + /**/ if (value == "distribute" || value == "") { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; } + else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; } + else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; } + else { invalid_param = true; break; } + break; + } + case constexpr_hash("--verbose-prompt"): + params.verbose_prompt = true; + break; + case constexpr_hash("--no-display-prompt"): + params.display_prompt = false; + break; + case constexpr_hash("-r"): + case constexpr_hash("--reverse-prompt"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.antiprompt.emplace_back(argv[i]); + break; + case constexpr_hash("-ld"): + case constexpr_hash("--logdir"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.logdir = argv[i]; + + if (params.logdir.back() != DIRECTORY_SEPARATOR) { + params.logdir += DIRECTORY_SEPARATOR; + } + break; + case constexpr_hash("--save-all-logits"): + case constexpr_hash("--kl-divergence-base"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.logits_file = argv[i]; + break; + case constexpr_hash("--perplexity"): + case constexpr_hash("--all-logits"): + params.logits_all = true; + break; + case constexpr_hash("--ppl-stride"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.ppl_stride = std::stoi(argv[i]); + break; + case constexpr_hash("-ptc"): + case constexpr_hash("--print-token-count"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_print = std::stoi(argv[i]); + break; + case constexpr_hash("--ppl-output-type"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.ppl_output_type = std::stoi(argv[i]); + break; + 
case constexpr_hash("--hellaswag"): + params.hellaswag = true; + break; + case constexpr_hash("--hellaswag-tasks"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.hellaswag_tasks = std::stoi(argv[i]); + break; + case constexpr_hash("--winogrande"): + params.winogrande = true; + break; + case constexpr_hash("--winogrande-tasks"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.winogrande_tasks = std::stoi(argv[i]); + break; + case constexpr_hash("--multiple-choice"): + params.multiple_choice = true; + break; + case constexpr_hash("--multiple-choice-tasks"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.multiple_choice_tasks = std::stoi(argv[i]); + break; + case constexpr_hash("--kl-divergence"): + params.kl_divergence = true; + break; + case constexpr_hash("--ignore-eos"): + params.ignore_eos = true; + break; + case constexpr_hash("--no-penalize-nl"): + sparams.penalize_nl = false; + break; + case constexpr_hash("-l"): + case constexpr_hash("--logit-bias"): { + if (++i >= argc) { + invalid_param = true; + break; + } + std::stringstream ss(argv[i]); + llama_token key; + char sign; + std::string value_str; + try { + if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { + sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? 
-1.0f : 1.0f); + } else { + throw std::exception(); + } + } catch (const std::exception&) { + invalid_param = true; + break; + } + break; + } + case constexpr_hash("-h"): + case constexpr_hash("--help"): + return false; + case constexpr_hash("--version"): + fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT); + fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET); + exit(0); + break; + case constexpr_hash("--random-prompt"): + params.random_prompt = true; + break; + case constexpr_hash("--in-prefix-bos"): + params.input_prefix_bos = true; + break; + case constexpr_hash("--in-prefix"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.input_prefix = argv[i]; + break; + case constexpr_hash("--in-suffix"): + if (++i >= argc) { + invalid_param = true; + break; + } + params.input_suffix = argv[i]; + break; + case constexpr_hash("--grammar"): + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.grammar = argv[i]; + break; + case constexpr_hash("--grammar-file"): { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + std::copy( + std::istreambuf_iterator(file), + std::istreambuf_iterator(), + std::back_inserter(sparams.grammar) + ); + break; + } + case constexpr_hash("--override-kv"): { + if (++i >= argc) { + invalid_param = true; + break; + } + char * sep = strchr(argv[i], '='); + if (sep == nullptr || sep - argv[i] >= 128) { + fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]); + invalid_param = true; + break; + } + struct llama_model_kv_override kvo; + std::strncpy(kvo.key, argv[i], sep - argv[i]); + kvo.key[sep - argv[i]] = 0; + sep++; + if (strncmp(sep, "int:", 4) == 0) { + sep += 4; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; + kvo.int_value = std::atol(sep); + } else if (strncmp(sep, "float:", 6) == 0) { + sep += 6; + kvo.tag = 
LLAMA_KV_OVERRIDE_TYPE_FLOAT; + kvo.float_value = std::atof(sep); + } else if (strncmp(sep, "bool:", 5) == 0) { + sep += 5; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; + if (std::strcmp(sep, "true") == 0) { + kvo.bool_value = true; + } else if (std::strcmp(sep, "false") == 0) { + kvo.bool_value = false; + } else { + fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]); + invalid_param = true; + break; + } + } else { + fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); + invalid_param = true; + break; + } + params.kv_overrides.push_back(kvo); + break; + } + default: #ifndef LOG_DISABLE_LOGS - // Parse args for logging parameters - } else if ( log_param_single_parse( argv[i] ) ) { - // Do nothing, log_param_single_parse automatically does it's thing - // and returns if a match was found and parsed. - } else if ( log_param_pair_parse( /*check_but_dont_parse*/ true, argv[i] ) ) { - // We have a matching known parameter requiring an argument, - // now we need to check if there is anything after this argv - // and flag invalid_param or parse it. - if (++i >= argc) { - invalid_param = true; - break; - } - if( !log_param_pair_parse( /*check_but_dont_parse*/ false, argv[i-1], argv[i]) ) { - invalid_param = true; - break; - } - // End of Parse args for logging parameters + // Parse args for logging parameters + if ( log_param_single_parse( argv[i] ) ) { + // Do nothing, log_param_single_parse automatically does it's thing + // and returns if a match was found and parsed. + break; + } else if ( log_param_pair_parse( /*check_but_dont_parse*/ true, argv[i] ) ) { + // We have a matching known parameter requiring an argument, + // now we need to check if there is anything after this argv + // and flag invalid_param or parse it. 
+ if (++i >= argc) { + invalid_param = true; + break; + } + if( !log_param_pair_parse( /*check_but_dont_parse*/ false, argv[i-1], argv[i]) ) { + invalid_param = true; + break; + } + // End of Parse args for logging parameters + } #endif // LOG_DISABLE_LOGS - } else { - throw std::invalid_argument("error: unknown argument: " + arg); + throw std::invalid_argument("error: unknown argument: " + arg); } + + if (invalid_param) + break; } if (invalid_param) { throw std::invalid_argument("error: invalid parameter for argument: " + arg);