cvector: better prompt handling, add "mean vector" method (#8069)

* remove completions file

* fix inverted vector

* add mean method

* code style

* remove inverted pca hotfix
This commit is contained in:
Xuan Son Nguyen 2024-06-25 13:59:54 +02:00 committed by GitHub
parent 48e6b92cc3
commit 49c03c79cd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 133 additions and 60 deletions

View file

@ -1263,11 +1263,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
// cvector params
if (arg == "--completions-file") {
CHECK_ARG
params.cvector_completions_file = argv[i];
return true;
}
if (arg == "--positive-file") {
CHECK_ARG
params.cvector_positive_file = argv[i];
@ -1278,11 +1273,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.cvector_negative_file = argv[i];
return true;
}
if (arg == "--completions") {
CHECK_ARG
params.n_completions = std::stoi(argv[i]);
return true;
}
if (arg == "--pca-batch") {
CHECK_ARG
params.n_pca_batch = std::stoi(argv[i]);
@ -1293,6 +1283,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.n_pca_iterations = std::stoi(argv[i]);
return true;
}
if (arg == "--method") {
CHECK_ARG
std::string value(argv[i]);
/**/ if (value == "pca") { params.cvector_dimre_method = DIMRE_METHOD_PCA; }
else if (value == "mean") { params.cvector_dimre_method = DIMRE_METHOD_MEAN; }
else { invalid_param = true; }
return true;
}
#ifndef LOG_DISABLE_LOGS
// Parse args for logging parameters
if (log_param_single_parse(argv[i])) {
@ -1626,11 +1624,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "cvector", "-o, --output FNAME", "output file (default: '%s')", params.cvector_outfile.c_str() });
options.push_back({ "cvector", " --positive-file FNAME", "positive prompts file, one prompt per line (default: '%s')", params.cvector_positive_file.c_str() });
options.push_back({ "cvector", " --negative-file FNAME", "negative prompts file, one prompt per line (default: '%s')", params.cvector_negative_file.c_str() });
options.push_back({ "cvector", " --completions-file FNAME",
"completions file (default: '%s')", params.cvector_completions_file.c_str() });
options.push_back({ "cvector", " --completions N", "number of lines of completions file to use (default: %d)", params.n_completions });
options.push_back({ "cvector", " --pca-batch N", "batch size used for PCA. Larger batch runs faster, but uses more memory (default: %d)", params.n_pca_batch });
options.push_back({ "cvector", " --pca-iter N", "number of iterations used for PCA (default: %d)", params.n_pca_iterations });
options.push_back({ "cvector", " --method {pca,mean}", "dimensionality reduction method to be used (default: pca)" });
printf("usage: %s [options]\n", argv[0]);

View file

@ -52,6 +52,12 @@ int32_t cpu_get_num_math();
// CLI argument parsing
//
// dimensionality reduction methods, used by cvector-generator
enum dimre_method {
DIMRE_METHOD_PCA,
DIMRE_METHOD_MEAN,
};
struct gpt_params {
uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
@ -238,13 +244,12 @@ struct gpt_params {
bool compute_ppl = true; // whether to compute perplexity
// cvector-generator params
int n_completions = 64;
int n_pca_batch = 20;
int n_pca_batch = 100;
int n_pca_iterations = 1000;
std::string cvector_outfile = "control_vector.gguf";
std::string cvector_completions_file = "examples/cvector-generator/completions.txt";
std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
std::string cvector_outfile = "control_vector.gguf";
std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
};
void gpt_params_handle_model_default(gpt_params & params);