Add command-line arguments for the number of threads, the number of completions-file lines to use, and always reloading the model
Refactored a few things and implemented the options described in the commit title.
This commit is contained in:
parent
4d7d71bc43
commit
62560367aa
1 changed files with 62 additions and 11 deletions
|
@ -39,13 +39,21 @@ struct callback_data {
|
|||
};
|
||||
|
||||
struct ctrl_params {
|
||||
/* default meta parameters */
|
||||
bool always_reload = false;
|
||||
int n_completions = 64;
|
||||
int n_threads = 8;
|
||||
|
||||
/* default filepaths */
|
||||
std::string outfile = "control_vector.gguf";
|
||||
std::string completions_file = "examples/control-vector-generator/completions.txt";
|
||||
/* pair of prompts to be used for generating the vectors */
|
||||
std::string positive_prompts_file = "examples/control-vector-generator/positive.txt";
|
||||
std::string negative_prompts_file = "examples/control-vector-generator/negative.txt";
|
||||
|
||||
/* pair of prompts to be used for generating the vectors */
|
||||
std::vector<std::string> positive_prompts;
|
||||
std::vector<std::string> negative_prompts;
|
||||
|
||||
/* pair of prompts to be used for testing */
|
||||
std::vector<std::string> positive_entries;
|
||||
std::vector<std::string> negative_entries;
|
||||
|
@ -59,10 +67,19 @@ static void print_usage(const char * executable) {
|
|||
printf("\n");
|
||||
printf("options:\n");
|
||||
printf(" -h, --help show this help message and exit\n");
|
||||
printf(" -o, --outfile output file (default: 'control_vector.gguf')\n");
|
||||
printf(" -cf, --completions-file completions file (default: 'examples/control-vector-generator/completions.txt')\n");
|
||||
printf(" -pf, --positive-file positive prompts file, one prompt per line (default: 'examples/control-vector-generator/positive.txt')\n");
|
||||
printf(" -nf, --negative-file negative prompts file, one prompt per line (default: 'examples/control-vector-generator/negative.txt')\n");
|
||||
printf(" -t, --num-threads number of threads to use (do not confuse with gpt-opts -t)\n");
|
||||
printf(" default: 8\n");
|
||||
printf(" -o, --outfile output file\n");
|
||||
printf(" default: 'control_vector.gguf'\n");
|
||||
printf(" -pf, --positive-file positive prompts file, one prompt per line\n");
|
||||
printf(" default: 'examples/control-vector-generator/positive.txt'\n");
|
||||
printf(" -nf, --negative-file negative prompts file, one prompt per line\n");
|
||||
printf(" default: 'examples/control-vector-generator/negative.txt'\n");
|
||||
printf(" -cf, --completions-file completions file\n");
|
||||
printf(" default: 'examples/control-vector-generator/completions.txt'\n");
|
||||
printf(" -nc, --num-completions number of lines of completions file to use\n");
|
||||
printf(" default: 64\n");
|
||||
printf(" --always-reload reload the model for every new template to parse\n");
|
||||
printf("\n");
|
||||
printf("gpt-opts:\n");
|
||||
printf(" other options from main\n");
|
||||
|
@ -122,6 +139,36 @@ static int ctrlvec_params_parse_ex(int argc, char ** argv, ctrl_params & params)
|
|||
throw std::invalid_argument("error: missing argument for " + arg);
|
||||
}
|
||||
}
|
||||
if (arg == "--num-completions" || arg == "-nc") {
|
||||
if (++arg_idx < argc && strncmp(argv[arg_idx], arg_prefix.c_str(), 2) != 0) {
|
||||
try {
|
||||
params.n_completions = std::stoi(argv[arg_idx]);
|
||||
}
|
||||
catch (const std::invalid_argument & ex) {
|
||||
throw std::invalid_argument("error: invalid argument for " + arg);
|
||||
}
|
||||
skipme += 2;
|
||||
} else {
|
||||
throw std::invalid_argument("error: missing argument for " + arg);
|
||||
}
|
||||
}
|
||||
if (arg == "--num-threads" || arg == "-t") {
|
||||
if (++arg_idx < argc && strncmp(argv[arg_idx], arg_prefix.c_str(), 2) != 0) {
|
||||
try {
|
||||
params.n_threads = std::stoi(argv[arg_idx]);
|
||||
}
|
||||
catch (const std::invalid_argument & ex) {
|
||||
throw std::invalid_argument("error: invalid argument for " + arg);
|
||||
}
|
||||
skipme += 2;
|
||||
} else {
|
||||
throw std::invalid_argument("error: missing argument for " + arg);
|
||||
}
|
||||
}
|
||||
if (arg == "--always-reload") {
|
||||
params.always_reload = true;
|
||||
skipme += 1;
|
||||
}
|
||||
// TODO it might be nice QoL to have single positive/negative args
|
||||
// we do not handle any other unknown arguments here because they will be handled by gpt_parse_params
|
||||
}
|
||||
|
@ -168,11 +215,13 @@ static std::string format_template(std::string persona, std::string suffix) {
|
|||
static void populate_entries(ctrl_params & cparams, std::string positive, std::string negative) {
|
||||
std::string line;
|
||||
std::ifstream completions_file(cparams.completions_file);
|
||||
int i = 0;
|
||||
if (completions_file.is_open()) {
|
||||
while (std::getline(completions_file, line)) {
|
||||
while (std::getline(completions_file, line) && i < cparams.n_completions) {
|
||||
// TODO replicate the truncations done by the python implementation
|
||||
cparams.positive_entries.push_back(format_template(positive, line));
|
||||
cparams.negative_entries.push_back(format_template(negative, line));
|
||||
i++;
|
||||
}
|
||||
completions_file.close();
|
||||
} else {
|
||||
|
@ -409,8 +458,7 @@ static std::vector<float> power_iteration(callback_data & cb_data, const float *
|
|||
}
|
||||
|
||||
// TODO translate to ggml
|
||||
static void pca(callback_data & cb_data) {
|
||||
size_t n_threads = 8;
|
||||
static void pca(callback_data & cb_data, size_t n_threads) {
|
||||
int n_layers = cb_data.v_diff.size();
|
||||
std::vector<std::thread> threads;
|
||||
cb_data.v_final.reserve(n_layers);
|
||||
|
@ -561,15 +609,18 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// need to reload the model so it doesn't run out of context
|
||||
// this should scale with -c option passed by main
|
||||
// TODO maybe we want to add an option to reload for every new prompt
|
||||
token_ct += 2 * max_seq_len;
|
||||
if (token_ct >= n_ctx) {
|
||||
if (token_ct > n_ctx || cparams.always_reload) {
|
||||
//break;
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
||||
token_ct = 2 * max_seq_len;
|
||||
}
|
||||
if (token_ct > n_ctx) {
|
||||
fprintf(stderr, "context size exceeded on iteration %d\n", i);
|
||||
break;
|
||||
}
|
||||
|
||||
printf("Evaluating prompt: \"%s\" - \"%s\" (%ld tokens)\n", positive_prompt.c_str(), negative_prompt.c_str(), max_seq_len);
|
||||
|
||||
|
@ -590,7 +641,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
concatenate_diffs(cb_data);
|
||||
pca(cb_data);
|
||||
pca(cb_data, cparams.n_threads);
|
||||
printf("v_final %f %f \n", cb_data.v_final[0][0], cb_data.v_final[0][1]);
|
||||
|
||||
llama_free(ctx);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue