From 4754bc22b695152a658e089eb8482d25018daf80 Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Sat, 12 Aug 2023 05:39:02 -0600 Subject: [PATCH] Add --cfg-negative-prompt-file option for examples --- examples/common.cpp | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/examples/common.cpp b/examples/common.cpp index 9f8aab9a2..c13b3dfd9 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -274,6 +274,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.cfg_negative_prompt = argv[i]; + } else if (arg == "--cfg-negative-prompt-file") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + std::copy( + std::istreambuf_iterator(file), + std::istreambuf_iterator(), + back_inserter(params.cfg_negative_prompt)); + if (params.cfg_negative_prompt.back() == '\n') { + params.cfg_negative_prompt.pop_back(); + } } else if (arg == "--cfg-scale") { if (++i >= argc) { invalid_param = true; @@ -567,8 +585,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); fprintf(stdout, " --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n"); fprintf(stdout, " --grammar-file FNAME file to read grammar from\n"); - fprintf(stdout, " --cfg-negative-prompt PROMPT \n"); + fprintf(stdout, " --cfg-negative-prompt PROMPT\n"); fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n"); + fprintf(stdout, " --cfg-negative-prompt-file FNAME\n"); + fprintf(stdout, " negative prompt file to use for guidance. (default: empty)\n"); fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); fprintf(stdout, " --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale); fprintf(stdout, " --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base);