From 62560367aa1ecf1d75df3baffee6e8dbff62fd7c Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Fri, 31 May 2024 21:27:14 -0400 Subject: [PATCH] add command-line args for num threads, num completions file lines, always reload model refactored a few things and did what the commit message says on the tin --- .../control-vector-generator.cpp | 73 ++++++++++++++++--- 1 file changed, 62 insertions(+), 11 deletions(-) diff --git a/examples/control-vector-generator/control-vector-generator.cpp b/examples/control-vector-generator/control-vector-generator.cpp index 2541fcb27..33da54ec7 100644 --- a/examples/control-vector-generator/control-vector-generator.cpp +++ b/examples/control-vector-generator/control-vector-generator.cpp @@ -39,13 +39,21 @@ struct callback_data { }; struct ctrl_params { + /* default meta parameters */ + bool always_reload = false; + int n_completions = 64; + int n_threads = 8; + + /* default filepaths */ std::string outfile = "control_vector.gguf"; std::string completions_file = "examples/control-vector-generator/completions.txt"; - /* pair of prompts to be used for generating the vectors */ std::string positive_prompts_file = "examples/control-vector-generator/positive.txt"; std::string negative_prompts_file = "examples/control-vector-generator/negative.txt"; + + /* pair of prompts to be used for generating the vectors */ std::vector positive_prompts; std::vector negative_prompts; + /* pair of prompts to be used for testing */ std::vector positive_entries; std::vector negative_entries; @@ -59,10 +67,19 @@ static void print_usage(const char * executable) { printf("\n"); printf("options:\n"); printf(" -h, --help show this help message and exit\n"); - printf(" -o, --outfile output file (default: 'control_vector.gguf')\n"); - printf(" -cf, --completions-file completions file (default: 'examples/control-vector-generator/completions.txt')\n"); - printf(" -pf, --positive-file positive prompts file, one prompt per line (default: 'examples/control-vector-generator/positive.txt')\n"); - printf(" -nf, --negative-file negative prompts file, one prompt per line (default: 'examples/control-vector-generator/negative.txt')\n"); + printf(" -t, --num-threads number of threads to use (do not confuse with gpt-opts -t)\n"); + printf(" default: 8\n"); + printf(" -o, --outfile output file\n"); + printf(" default: 'control_vector.gguf'\n"); + printf(" -pf, --positive-file positive prompts file, one prompt per line\n"); + printf(" default: 'examples/control-vector-generator/positive.txt'\n"); + printf(" -nf, --negative-file negative prompts file, one prompt per line\n"); + printf(" default: 'examples/control-vector-generator/negative.txt'\n"); + printf(" -cf, --completions-file completions file\n"); + printf(" default: 'examples/control-vector-generator/completions.txt'\n"); + printf(" -nc, --num-completions number of lines of completions file to use\n"); + printf(" default: 64\n"); + printf(" --always-reload reload the model for every new template to parse\n"); printf("\n"); printf("gpt-opts:\n"); printf(" other options from main\n"); @@ -122,6 +139,36 @@ static int ctrlvec_params_parse_ex(int argc, char ** argv, ctrl_params & params) throw std::invalid_argument("error: missing argument for " + arg); } } + if (arg == "--num-completions" || arg == "-nc") { + if (++arg_idx < argc && strncmp(argv[arg_idx], arg_prefix.c_str(), 2) != 0) { + try { + params.n_completions = std::stoi(argv[arg_idx]); + } + catch (const std::invalid_argument & ex) { + throw std::invalid_argument("error: invalid argument for " + arg); + } + skipme += 2; + } else { + throw std::invalid_argument("error: missing argument for " + arg); + } + } + if (arg == "--num-threads" || arg == "-t") { + if (++arg_idx < argc && strncmp(argv[arg_idx], arg_prefix.c_str(), 2) != 0) { + try { + params.n_threads = std::stoi(argv[arg_idx]); + } + catch (const std::invalid_argument & ex) { + throw std::invalid_argument("error: invalid argument for " + arg); + } + skipme += 2; + } else { + throw std::invalid_argument("error: missing argument for " + arg); + } + } + if (arg == "--always-reload") { + params.always_reload = true; + skipme += 1; + } // TODO it might be nice QoL to have single positive/negative args // we do not handle any other unknown arguments here because they will be handled by gpt_parse_params } @@ -168,11 +215,13 @@ static std::string format_template(std::string persona, std::string suffix) { static void populate_entries(ctrl_params & cparams, std::string positive, std::string negative) { std::string line; std::ifstream completions_file(cparams.completions_file); + int i = 0; if (completions_file.is_open()) { - while (std::getline(completions_file, line)) { + while (std::getline(completions_file, line) && i < cparams.n_completions) { // TODO replicate the truncations done by the python implementation cparams.positive_entries.push_back(format_template(positive, line)); cparams.negative_entries.push_back(format_template(negative, line)); + i++; } completions_file.close(); } else { @@ -409,8 +458,7 @@ static std::vector power_iteration(callback_data & cb_data, const float * } // TODO translate to ggml -static void pca(callback_data & cb_data) { - size_t n_threads = 8; +static void pca(callback_data & cb_data, size_t n_threads) { int n_layers = cb_data.v_diff.size(); std::vector threads; cb_data.v_final.reserve(n_layers); @@ -561,15 +609,18 @@ int main(int argc, char ** argv) { // need to reload the model so it doesn't run out of context // this should scale with -c option passed by main - // TODO maybe we want to add an option to reload for every new prompt token_ct += 2 * max_seq_len; - if (token_ct >= n_ctx) { + if (token_ct > n_ctx || cparams.always_reload) { //break; llama_free(ctx); llama_free_model(model); std::tie(model, ctx) = llama_init_from_gpt_params(params); token_ct = 2 * max_seq_len; } + if (token_ct > n_ctx) { + fprintf(stderr, "context size exceeded on iteration %d\n", i); + break; + } printf("Evaluating prompt: \"%s\" - \"%s\" (%ld tokens)\n", positive_prompt.c_str(), negative_prompt.c_str(), max_seq_len); @@ -590,7 +641,7 @@ int main(int argc, char ** argv) { } concatenate_diffs(cb_data); - pca(cb_data); + pca(cb_data, cparams.n_threads); printf("v_final %f %f \n", cb_data.v_final[0][0], cb_data.v_final[0][1]); llama_free(ctx);