Add command-line args for the number of threads, the number of completions-file lines to use, and always reloading the model

Refactored a few things and added the command-line options described in the commit title (thread count, completions-file line limit, and an always-reload flag).
This commit is contained in:
Christian Zhou-Zheng 2024-05-31 21:27:14 -04:00
parent 4d7d71bc43
commit 62560367aa

View file

@ -39,13 +39,21 @@ struct callback_data {
};
struct ctrl_params {
/* default meta parameters */
bool always_reload = false;
int n_completions = 64;
int n_threads = 8;
/* default filepaths */
std::string outfile = "control_vector.gguf";
std::string completions_file = "examples/control-vector-generator/completions.txt";
/* pair of prompts to be used for generating the vectors */
std::string positive_prompts_file = "examples/control-vector-generator/positive.txt";
std::string negative_prompts_file = "examples/control-vector-generator/negative.txt";
/* pair of prompts to be used for generating the vectors */
std::vector<std::string> positive_prompts;
std::vector<std::string> negative_prompts;
/* pair of prompts to be used for testing */
std::vector<std::string> positive_entries;
std::vector<std::string> negative_entries;
@ -59,10 +67,19 @@ static void print_usage(const char * executable) {
printf("\n");
printf("options:\n");
printf(" -h, --help show this help message and exit\n");
printf(" -o, --outfile output file (default: 'control_vector.gguf')\n");
printf(" -cf, --completions-file completions file (default: 'examples/control-vector-generator/completions.txt')\n");
printf(" -pf, --positive-file positive prompts file, one prompt per line (default: 'examples/control-vector-generator/positive.txt')\n");
printf(" -nf, --negative-file negative prompts file, one prompt per line (default: 'examples/control-vector-generator/negative.txt')\n");
printf(" -t, --num-threads number of threads to use (do not confuse with gpt-opts -t)\n");
printf(" default: 8\n");
printf(" -o, --outfile output file\n");
printf(" default: 'control_vector.gguf'\n");
printf(" -pf, --positive-file positive prompts file, one prompt per line\n");
printf(" default: 'examples/control-vector-generator/positive.txt'\n");
printf(" -nf, --negative-file negative prompts file, one prompt per line\n");
printf(" default: 'examples/control-vector-generator/negative.txt'\n");
printf(" -cf, --completions-file completions file\n");
printf(" default: 'examples/control-vector-generator/completions.txt'\n");
printf(" -nc, --num-completions number of lines of completions file to use\n");
printf(" default: 64\n");
printf(" --always-reload reload the model for every new template to parse\n");
printf("\n");
printf("gpt-opts:\n");
printf(" other options from main\n");
@ -122,6 +139,36 @@ static int ctrlvec_params_parse_ex(int argc, char ** argv, ctrl_params & params)
throw std::invalid_argument("error: missing argument for " + arg);
}
}
if (arg == "--num-completions" || arg == "-nc") {
if (++arg_idx < argc && strncmp(argv[arg_idx], arg_prefix.c_str(), 2) != 0) {
try {
params.n_completions = std::stoi(argv[arg_idx]);
}
catch (const std::invalid_argument & ex) {
throw std::invalid_argument("error: invalid argument for " + arg);
}
skipme += 2;
} else {
throw std::invalid_argument("error: missing argument for " + arg);
}
}
if (arg == "--num-threads" || arg == "-t") {
if (++arg_idx < argc && strncmp(argv[arg_idx], arg_prefix.c_str(), 2) != 0) {
try {
params.n_threads = std::stoi(argv[arg_idx]);
}
catch (const std::invalid_argument & ex) {
throw std::invalid_argument("error: invalid argument for " + arg);
}
skipme += 2;
} else {
throw std::invalid_argument("error: missing argument for " + arg);
}
}
if (arg == "--always-reload") {
params.always_reload = true;
skipme += 1;
}
// TODO it might be nice QoL to have single positive/negative args
// we do not handle any other unknown arguments here because they will be handled by gpt_parse_params
}
@ -168,11 +215,13 @@ static std::string format_template(std::string persona, std::string suffix) {
static void populate_entries(ctrl_params & cparams, std::string positive, std::string negative) {
std::string line;
std::ifstream completions_file(cparams.completions_file);
int i = 0;
if (completions_file.is_open()) {
while (std::getline(completions_file, line)) {
while (std::getline(completions_file, line) && i < cparams.n_completions) {
// TODO replicate the truncations done by the python implementation
cparams.positive_entries.push_back(format_template(positive, line));
cparams.negative_entries.push_back(format_template(negative, line));
i++;
}
completions_file.close();
} else {
@ -409,8 +458,7 @@ static std::vector<float> power_iteration(callback_data & cb_data, const float *
}
// TODO translate to ggml
static void pca(callback_data & cb_data) {
size_t n_threads = 8;
static void pca(callback_data & cb_data, size_t n_threads) {
int n_layers = cb_data.v_diff.size();
std::vector<std::thread> threads;
cb_data.v_final.reserve(n_layers);
@ -561,15 +609,18 @@ int main(int argc, char ** argv) {
// need to reload the model so it doesn't run out of context
// this should scale with -c option passed by main
// TODO maybe we want to add an option to reload for every new prompt
token_ct += 2 * max_seq_len;
if (token_ct >= n_ctx) {
if (token_ct > n_ctx || cparams.always_reload) {
//break;
llama_free(ctx);
llama_free_model(model);
std::tie(model, ctx) = llama_init_from_gpt_params(params);
token_ct = 2 * max_seq_len;
}
if (token_ct > n_ctx) {
fprintf(stderr, "context size exceeded on iteration %d\n", i);
break;
}
printf("Evaluating prompt: \"%s\" - \"%s\" (%ld tokens)\n", positive_prompt.c_str(), negative_prompt.c_str(), max_seq_len);
@ -590,7 +641,7 @@ int main(int argc, char ** argv) {
}
concatenate_diffs(cb_data);
pca(cb_data);
pca(cb_data, cparams.n_threads);
printf("v_final %f %f \n", cb_data.v_final[0][0], cb_data.v_final[0][1]);
llama_free(ctx);