define retrieval-only parameters in retrieval.cpp
This commit is contained in:
parent
f9dc033797
commit
56b7db971e
3 changed files with 84 additions and 54 deletions
|
@ -154,7 +154,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int & i, bool & invalid_param) {
|
bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int & i, bool & invalid_param) {
|
||||||
std::string arg = argv[i];
|
std::string arg = argv[i];
|
||||||
llama_sampling_params& sparams = params.sparams;
|
llama_sampling_params& sparams = params.sparams;
|
||||||
|
|
||||||
|
@ -276,43 +276,6 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (arg == "--context-files") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
while(true) {
|
|
||||||
std::ifstream file(argv[i]);
|
|
||||||
if (!file) {
|
|
||||||
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
// store the external file name in params
|
|
||||||
params.context_files.push_back(argv[i]);
|
|
||||||
if (i + 1 >= argc || argv[i + 1][0] == '-') {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (arg == "--chunk-size") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
params.chunk_size = std::stoi(argv[i]);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (arg == "--chunk-separator") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
params.chunk_separator = argv[i];
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (arg == "-n" || arg == "--n-predict") {
|
if (arg == "-n" || arg == "--n-predict") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
|
@ -1319,11 +1282,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||||
printf(" prompt file to start generation.\n");
|
printf(" prompt file to start generation.\n");
|
||||||
printf(" -bf FNAME, --binary-file FNAME\n");
|
printf(" -bf FNAME, --binary-file FNAME\n");
|
||||||
printf(" binary file containing multiple choice tasks.\n");
|
printf(" binary file containing multiple choice tasks.\n");
|
||||||
printf(" --context-files FNAME1 FNAME2...\n");
|
|
||||||
printf(" files containing context to embed.\n");
|
|
||||||
printf(" --chunk-size N minimum length of embedded text chunk (default:%d)\n", params.chunk_size);
|
|
||||||
printf(" --chunk-separator STRING\n");
|
|
||||||
printf(" string to separate chunks (default: \"\\n\")\n");
|
|
||||||
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
|
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
|
||||||
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
|
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
|
||||||
printf(" -b N, --batch-size N logical maximum batch size (default: %d)\n", params.n_batch);
|
printf(" -b N, --batch-size N logical maximum batch size (default: %d)\n", params.n_batch);
|
||||||
|
|
|
@ -79,9 +79,6 @@ struct gpt_params {
|
||||||
float yarn_beta_slow = 1.0f; // YaRN high correction dim
|
float yarn_beta_slow = 1.0f; // YaRN high correction dim
|
||||||
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
||||||
float defrag_thold = -1.0f; // KV cache defragmentation threshold
|
float defrag_thold = -1.0f; // KV cache defragmentation threshold
|
||||||
std::vector<std::string> context_files;// context files to embed
|
|
||||||
int32_t chunk_size = 64; // chunk size for context embedding
|
|
||||||
std::string chunk_separator = "\n"; // chunk separator for context embedding
|
|
||||||
|
|
||||||
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
|
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
|
||||||
|
|
||||||
|
@ -170,6 +167,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
|
||||||
|
|
||||||
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
|
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
|
||||||
|
|
||||||
|
bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int & i, bool & invalid_param);
|
||||||
|
|
||||||
std::string get_system_info(const gpt_params & params);
|
std::string get_system_info(const gpt_params & params);
|
||||||
|
|
||||||
std::string gpt_random_prompt(std::mt19937 & rng);
|
std::string gpt_random_prompt(std::mt19937 & rng);
|
||||||
|
|
|
@ -4,6 +4,80 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
|
||||||
|
// Options that are specific to the retrieval example and therefore kept
// separate from the shared gpt_params.
struct retrieval_params {
    // context files to embed
    std::vector<std::string> context_files;

    // chunk size for context embedding
    int32_t chunk_size{64};

    // chunk separator for context embedding
    std::string chunk_separator{"\n"};
};
|
||||||
|
|
||||||
|
// Prints the usage text of the retrieval example: a short header, the
// retrieval-only options, and then the common options via gpt_print_usage().
//
//   argc, argv  - forwarded unchanged to gpt_print_usage()
//   gpt_params  - forwarded to gpt_print_usage()
//   params      - retrieval options; chunk_size is shown as the default value
//
// The whole text is written to stdout. The header used to go to stderr while
// the option lines went to stdout, which split the help text whenever only one
// of the two streams was redirected.
static void retrieval_params_print_usage(int argc, char** argv, gpt_params & gpt_params, retrieval_params & params) {
    printf("usage: retrieval [options]\n");
    printf("options:\n");
    printf(" --context-files FNAME1 FNAME2...\n");
    printf(" files containing context to embed.\n");
    printf(" --chunk-size N minimum length of embedded text chunk (default:%d)\n", params.chunk_size);
    printf(" --chunk-separator STRING\n");
    printf(" string to separate chunks (default: \"\\n\")\n");
    gpt_print_usage(argc, argv, gpt_params);
}
|
||||||
|
|
||||||
|
static void retrieval_params_parse(int argc, char ** argv, gpt_params & gpt_params, retrieval_params & retrieval_params) {
|
||||||
|
int i = 1;
|
||||||
|
std::string arg;
|
||||||
|
while(i < argc) {
|
||||||
|
arg = argv[i];
|
||||||
|
bool invalid_gpt_param = false;
|
||||||
|
if(gpt_params_find_arg(argc, argv, gpt_params, i, invalid_gpt_param)) {
|
||||||
|
if (invalid_gpt_param) {
|
||||||
|
fprintf(stderr, "error: invalid argument: %s\n", arg.c_str());
|
||||||
|
retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
// option was parsed by gpt_params_find_arg
|
||||||
|
} else if (arg == "--context-files") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
fprintf(stderr, "error: missing argument for --context-files\n");
|
||||||
|
retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
while(true) {
|
||||||
|
std::ifstream file(argv[i]);
|
||||||
|
if (!file) {
|
||||||
|
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
|
||||||
|
retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
// store the external file name in params
|
||||||
|
retrieval_params.context_files.push_back(argv[i]);
|
||||||
|
if (i + 1 >= argc || argv[i + 1][0] == '-') {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
} else if (arg == "--chunk-size") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
fprintf(stderr, "error: missing argument for --chunk-size\n");
|
||||||
|
retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
retrieval_params.chunk_size = std::stoi(argv[i]);
|
||||||
|
} else if (arg == "--chunk-separator") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
fprintf(stderr, "error: missing argument for --chunk-separator\n");
|
||||||
|
retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
retrieval_params.chunk_separator = argv[i];
|
||||||
|
} else {
|
||||||
|
// unknown argument
|
||||||
|
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||||
|
retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct chunk {
|
struct chunk {
|
||||||
// filename
|
// filename
|
||||||
std::string filename;
|
std::string filename;
|
||||||
|
@ -103,19 +177,18 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
|
retrieval_params retrieval_params;
|
||||||
|
|
||||||
if (!gpt_params_parse(argc, argv, params)) {
|
retrieval_params_parse(argc, argv, params, retrieval_params);
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// For BERT models, batch size must be equal to ubatch size
|
// For BERT models, batch size must be equal to ubatch size
|
||||||
params.n_ubatch = params.n_batch;
|
params.n_ubatch = params.n_batch;
|
||||||
|
|
||||||
if (params.chunk_size <= 0) {
|
if (retrieval_params.chunk_size <= 0) {
|
||||||
fprintf(stderr, "chunk_size must be positive\n");
|
fprintf(stderr, "chunk_size must be positive\n");
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
if (params.context_files.empty()) {
|
if (retrieval_params.context_files.empty()) {
|
||||||
fprintf(stderr, "context_files must be specified\n");
|
fprintf(stderr, "context_files must be specified\n");
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -128,13 +201,13 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
printf("processing files:\n");
|
printf("processing files:\n");
|
||||||
for (auto & context_file : params.context_files) {
|
for (auto & context_file : retrieval_params.context_files) {
|
||||||
printf("%s\n", context_file.c_str());
|
printf("%s\n", context_file.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<chunk> chunks;
|
std::vector<chunk> chunks;
|
||||||
for (auto & context_file : params.context_files) {
|
for (auto & context_file : retrieval_params.context_files) {
|
||||||
std::vector<chunk> file_chunk = chunk_file(context_file, params.chunk_size, params.chunk_separator);
|
std::vector<chunk> file_chunk = chunk_file(context_file, retrieval_params.chunk_size, retrieval_params.chunk_separator);
|
||||||
chunks.insert(chunks.end(), file_chunk.begin(), file_chunk.end());
|
chunks.insert(chunks.end(), file_chunk.begin(), file_chunk.end());
|
||||||
}
|
}
|
||||||
printf("Number of chunks: %ld\n", chunks.size());
|
printf("Number of chunks: %ld\n", chunks.size());
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue