From 56b7db971e1851964f026a9977d3fbd2e7f46ca1 Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Sun, 24 Mar 2024 20:48:38 +0900 Subject: [PATCH] define retrieval-only parameters in retrieval.cpp --- common/common.cpp | 44 +--------------- common/common.h | 5 +- examples/retrieval/retrieval.cpp | 89 +++++++++++++++++++++++++++++--- 3 files changed, 84 insertions(+), 54 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 290bfff80..ad529c585 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -154,7 +154,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { return result; } -static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int & i, bool & invalid_param) { +bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int & i, bool & invalid_param) { std::string arg = argv[i]; llama_sampling_params& sparams = params.sparams; @@ -276,43 +276,6 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int } return true; } - if (arg == "--context-files") { - if (++i >= argc) { - invalid_param = true; - return true; - } - while(true) { - std::ifstream file(argv[i]); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - break; - } - // store the external file name in params - params.context_files.push_back(argv[i]); - if (i + 1 >= argc || argv[i + 1][0] == '-') { - break; - } - i++; - } - return true; - } - if (arg == "--chunk-size") { - if (++i >= argc) { - invalid_param = true; - return true; - } - params.chunk_size = std::stoi(argv[i]); - return true; - } - if (arg == "--chunk-separator") { - if (++i >= argc) { - invalid_param = true; - return true; - } - params.chunk_separator = argv[i]; - return true; - } if (arg == "-n" || arg == "--n-predict") { if (++i >= argc) { invalid_param = true; @@ -1319,11 +1282,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" prompt file to start generation.\n"); printf(" -bf FNAME, --binary-file FNAME\n"); printf(" binary file containing multiple choice tasks.\n"); - printf(" --context-files FNAME1 FNAME2...\n"); - printf(" files containing context to embed.\n"); - printf(" --chunk-size N minimum length of embedded text chunk (default:%d)\n", params.chunk_size); - printf(" --chunk-separator STRING\n"); - printf(" string to separate chunks (default: \"\\n\")\n"); printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); printf(" -b N, --batch-size N logical maximum batch size (default: %d)\n", params.n_batch); diff --git a/common/common.h b/common/common.h index 31b971069..046a994ce 100644 --- a/common/common.h +++ b/common/common.h @@ -79,9 +79,6 @@ struct gpt_params { float yarn_beta_slow = 1.0f; // YaRN high correction dim int32_t yarn_orig_ctx = 0; // YaRN original context length float defrag_thold = -1.0f; // KV cache defragmentation threshold - std::vector context_files;// context files to embed - int32_t chunk_size = 64; // chunk size for context embedding - std::string chunk_separator = "\n"; // chunk separator for context embedding ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED; @@ -170,6 +167,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params); void gpt_print_usage(int argc, char ** argv, const gpt_params & params); +bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int & i, bool & invalid_param); + std::string get_system_info(const gpt_params & params); std::string gpt_random_prompt(std::mt19937 & rng); diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 467e2f027..b1e2b5bbf 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -4,6 +4,80 @@ #include #include +struct retrieval_params { + std::vector context_files;// context files to embed + int32_t chunk_size = 64; // chunk size for context embedding + std::string chunk_separator = "\n"; // chunk separator for context embedding +}; + +static void retrieval_params_print_usage(int argc, char** argv, gpt_params & gpt_params, retrieval_params & params) { + fprintf(stderr, "usage: retrieval [options]\n"); + fprintf(stderr, "options:\n"); + printf(" --context-files FNAME1 FNAME2...\n"); + printf(" files containing context to embed.\n"); + printf(" --chunk-size N minimum length of embedded text chunk (default:%d)\n", params.chunk_size); + printf(" --chunk-separator STRING\n"); + printf(" string to separate chunks (default: \"\\n\")\n"); + gpt_print_usage(argc, argv, gpt_params); +} + +static void retrieval_params_parse(int argc, char ** argv, gpt_params & gpt_params, retrieval_params & retrieval_params) { + int i = 1; + std::string arg; + while(i < argc) { + arg = argv[i]; + bool invalid_gpt_param = false; + if(gpt_params_find_arg(argc, argv, gpt_params, i, invalid_gpt_param)) { + if (invalid_gpt_param) { + fprintf(stderr, "error: invalid argument: %s\n", arg.c_str()); + retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params); + exit(1); + } + // option was parsed by gpt_params_find_arg + } else if (arg == "--context-files") { + if (++i >= argc) { + fprintf(stderr, "error: missing argument for --context-files\n"); + retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params); + exit(1); + } + while(true) { + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params); + exit(1); + } + // store the external file name in params + retrieval_params.context_files.push_back(argv[i]); + if (i + 1 >= argc || argv[i + 1][0] == '-') { + break; + } + i++; + } + } else if (arg == "--chunk-size") { + if (++i >= argc) { + fprintf(stderr, "error: missing argument for --chunk-size\n"); + retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params); + exit(1); + } + retrieval_params.chunk_size = std::stoi(argv[i]); + } else if (arg == "--chunk-separator") { + if (++i >= argc) { + fprintf(stderr, "error: missing argument for --chunk-separator\n"); + retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params); + exit(1); + } + retrieval_params.chunk_separator = argv[i]; + } else { + // unknown argument + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params); + exit(1); + } + i++; + } +} + struct chunk { // filename std::string filename; @@ -103,19 +177,18 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu int main(int argc, char ** argv) { gpt_params params; + retrieval_params retrieval_params; - if (!gpt_params_parse(argc, argv, params)) { - return 1; - } + retrieval_params_parse(argc, argv, params, retrieval_params); // For BERT models, batch size must be equal to ubatch size params.n_ubatch = params.n_batch; - if (params.chunk_size <= 0) { + if (retrieval_params.chunk_size <= 0) { fprintf(stderr, "chunk_size must be positive\n"); return 1; } - if (params.context_files.empty()) { + if (retrieval_params.context_files.empty()) { fprintf(stderr, "context_files must be specified\n"); return 1; } @@ -128,13 +201,13 @@ int main(int argc, char ** argv) { } printf("processing files:\n"); - for (auto & context_file : params.context_files) { + for (auto & context_file : retrieval_params.context_files) { printf("%s\n", context_file.c_str()); } std::vector chunks; - for (auto & context_file : params.context_files) { - std::vector file_chunk = chunk_file(context_file, params.chunk_size, params.chunk_separator); + for (auto & context_file : retrieval_params.context_files) { + std::vector file_chunk = chunk_file(context_file, retrieval_params.chunk_size, retrieval_params.chunk_separator); chunks.insert(chunks.end(), file_chunk.begin(), file_chunk.end()); } printf("Number of chunks: %ld\n", chunks.size());