define retrieval-only parameters in retrieval.cpp

This commit is contained in:
Minsoo Cheong 2024-03-24 20:48:38 +09:00
parent f9dc033797
commit 56b7db971e
3 changed files with 84 additions and 54 deletions

View file

@ -154,7 +154,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
return result;
}
static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int & i, bool & invalid_param) {
bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int & i, bool & invalid_param) {
std::string arg = argv[i];
llama_sampling_params& sparams = params.sparams;
@ -276,43 +276,6 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
}
return true;
}
if (arg == "--context-files") {
if (++i >= argc) {
invalid_param = true;
return true;
}
while(true) {
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
invalid_param = true;
break;
}
// store the external file name in params
params.context_files.push_back(argv[i]);
if (i + 1 >= argc || argv[i + 1][0] == '-') {
break;
}
i++;
}
return true;
}
if (arg == "--chunk-size") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.chunk_size = std::stoi(argv[i]);
return true;
}
if (arg == "--chunk-separator") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.chunk_separator = argv[i];
return true;
}
if (arg == "-n" || arg == "--n-predict") {
if (++i >= argc) {
invalid_param = true;
@ -1319,11 +1282,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" prompt file to start generation.\n");
printf(" -bf FNAME, --binary-file FNAME\n");
printf(" binary file containing multiple choice tasks.\n");
printf(" --context-files FNAME1 FNAME2...\n");
printf(" files containing context to embed.\n");
printf(" --chunk-size N minimum length of embedded text chunk (default:%d)\n", params.chunk_size);
printf(" --chunk-separator STRING\n");
printf(" string to separate chunks (default: \"\\n\")\n");
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N logical maximum batch size (default: %d)\n", params.n_batch);

View file

@ -79,9 +79,6 @@ struct gpt_params {
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
float defrag_thold = -1.0f; // KV cache defragmentation threshold
std::vector<std::string> context_files;// context files to embed
int32_t chunk_size = 64; // chunk size for context embedding
std::string chunk_separator = "\n"; // chunk separator for context embedding
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
@ -170,6 +167,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int & i, bool & invalid_param);
std::string get_system_info(const gpt_params & params);
std::string gpt_random_prompt(std::mt19937 & rng);

View file

@ -4,6 +4,80 @@
#include <algorithm>
#include <fstream>
struct retrieval_params {
std::vector<std::string> context_files;// context files to embed
int32_t chunk_size = 64; // chunk size for context embedding
std::string chunk_separator = "\n"; // chunk separator for context embedding
};
static void retrieval_params_print_usage(int argc, char** argv, gpt_params & gpt_params, retrieval_params & params) {
fprintf(stderr, "usage: retrieval [options]\n");
fprintf(stderr, "options:\n");
printf(" --context-files FNAME1 FNAME2...\n");
printf(" files containing context to embed.\n");
printf(" --chunk-size N minimum length of embedded text chunk (default:%d)\n", params.chunk_size);
printf(" --chunk-separator STRING\n");
printf(" string to separate chunks (default: \"\\n\")\n");
gpt_print_usage(argc, argv, gpt_params);
}
static void retrieval_params_parse(int argc, char ** argv, gpt_params & gpt_params, retrieval_params & retrieval_params) {
int i = 1;
std::string arg;
while(i < argc) {
arg = argv[i];
bool invalid_gpt_param = false;
if(gpt_params_find_arg(argc, argv, gpt_params, i, invalid_gpt_param)) {
if (invalid_gpt_param) {
fprintf(stderr, "error: invalid argument: %s\n", arg.c_str());
retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
exit(1);
}
// option was parsed by gpt_params_find_arg
} else if (arg == "--context-files") {
if (++i >= argc) {
fprintf(stderr, "error: missing argument for --context-files\n");
retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
exit(1);
}
while(true) {
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
exit(1);
}
// store the external file name in params
retrieval_params.context_files.push_back(argv[i]);
if (i + 1 >= argc || argv[i + 1][0] == '-') {
break;
}
i++;
}
} else if (arg == "--chunk-size") {
if (++i >= argc) {
fprintf(stderr, "error: missing argument for --chunk-size\n");
retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
exit(1);
}
retrieval_params.chunk_size = std::stoi(argv[i]);
} else if (arg == "--chunk-separator") {
if (++i >= argc) {
fprintf(stderr, "error: missing argument for --chunk-separator\n");
retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
exit(1);
}
retrieval_params.chunk_separator = argv[i];
} else {
// unknown argument
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
exit(1);
}
i++;
}
}
struct chunk {
// filename
std::string filename;
@ -103,19 +177,18 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
int main(int argc, char ** argv) {
gpt_params params;
retrieval_params retrieval_params;
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
retrieval_params_parse(argc, argv, params, retrieval_params);
// For BERT models, batch size must be equal to ubatch size
params.n_ubatch = params.n_batch;
if (params.chunk_size <= 0) {
if (retrieval_params.chunk_size <= 0) {
fprintf(stderr, "chunk_size must be positive\n");
return 1;
}
if (params.context_files.empty()) {
if (retrieval_params.context_files.empty()) {
fprintf(stderr, "context_files must be specified\n");
return 1;
}
@ -128,13 +201,13 @@ int main(int argc, char ** argv) {
}
printf("processing files:\n");
for (auto & context_file : params.context_files) {
for (auto & context_file : retrieval_params.context_files) {
printf("%s\n", context_file.c_str());
}
std::vector<chunk> chunks;
for (auto & context_file : params.context_files) {
std::vector<chunk> file_chunk = chunk_file(context_file, params.chunk_size, params.chunk_separator);
for (auto & context_file : retrieval_params.context_files) {
std::vector<chunk> file_chunk = chunk_file(context_file, retrieval_params.chunk_size, retrieval_params.chunk_separator);
chunks.insert(chunks.end(), file_chunk.begin(), file_chunk.end());
}
printf("Number of chunks: %ld\n", chunks.size());