retrieval : migrate to gpt_params

Georgi Gerganov 2024-06-04 12:42:02 +03:00
parent a149eed043
commit c4b6b83811
4 changed files with 60 additions and 75 deletions

common/common.cpp

@@ -1498,6 +1498,36 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.n_pl.insert(params.n_pl.end(), p.begin(), p.end());
         return true;
     }
+    if (arg == "--context-file") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        std::ifstream file(argv[i], std::ios::binary);
+        if (!file) {
+            fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+            invalid_param = true;
+            return true;
+        }
+        params.context_files.push_back(argv[i]);
+        return true;
+    }
+    if (arg == "--chunk-size") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.chunk_size = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--chunk-separator") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.chunk_separator = argv[i];
+        return true;
+    }
 #ifndef LOG_DISABLE_LOGS
     // Parse args for logging parameters
     if (log_param_single_parse(argv[i])) {
@@ -1754,6 +1784,12 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*",         "-hfr, --hf-repo REPO",          "Hugging Face model repository (default: unused)" });
     options.push_back({ "*",         "-hff, --hf-file FILE",          "Hugging Face model file (default: unused)" });
 
+    options.push_back({ "retrieval" });
+    options.push_back({ "retrieval", "       --context-file FNAME",   "file to load context from (repeat to specify multiple files)" });
+    options.push_back({ "retrieval", "       --chunk-size N",         "minimum length of embedded text chunks (default: %d)", params.chunk_size });
+    options.push_back({ "retrieval", "       --chunk-separator STRING",
+                                     "separator between chunks (default: '%s')", params.chunk_separator.c_str() });
+
     options.push_back({ "bench" });
     options.push_back({ "bench",     "-pps",                          "is the prompt shared across parallel sequences (default: %s)", params.is_pp_shared ? "true" : "false" });
     options.push_back({ "bench",     "-npp n0,n1,...",                "number of prompt tokens" });

common/common.h

@@ -210,6 +210,13 @@ struct gpt_params {
     std::vector<int32_t> n_pp;
     std::vector<int32_t> n_tg;
     std::vector<int32_t> n_pl;
 
+    // retrieval params
+    std::vector<std::string> context_files; // context files to embed
+    int32_t chunk_size = 64;                // chunk size for context embedding
+    std::string chunk_separator = "\n";     // chunk separator for context embedding
 };
 
 void gpt_params_handle_model_default(gpt_params & params);
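chunk_file, the helper that consumes chunk_size and chunk_separator, is untouched by this commit and not shown in the diff. As a rough sketch of the semantics the two fields encode (split on the separator, keep accumulating until a piece reaches the minimum length), with the function below being illustrative rather than the actual implementation:

    // Illustrative only: accumulate separator-delimited pieces until each
    // chunk is at least min_size characters. Assumes a non-empty separator
    // (the default is "\n").
    #include <string>
    #include <vector>

    static std::vector<std::string> split_into_chunks(
            const std::string & text, size_t min_size, const std::string & sep) {
        std::vector<std::string> chunks;
        std::string cur;

        size_t pos = 0;
        while (pos < text.size()) {
            size_t next = text.find(sep, pos);
            if (next == std::string::npos) {
                next = text.size();
            }
            cur.append(text, pos, next - pos);
            pos = next + sep.size();

            if (cur.size() >= min_size) {
                chunks.push_back(cur);
                cur.clear();
            }
        }
        if (!cur.empty()) {
            chunks.push_back(cur); // trailing remainder below the minimum size
        }
        return chunks;
    }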

examples/retrieval/retrieval.cpp

@@ -4,72 +4,12 @@
 #include <algorithm>
 #include <fstream>
 
-struct retrieval_params {
-    std::vector<std::string> context_files; // context files to embed
-    int32_t chunk_size = 64; // chunk size for context embedding
-    std::string chunk_separator = "\n"; // chunk separator for context embedding
-};
+static void print_usage(int argc, char ** argv, const gpt_params & params) {
+    gpt_params_print_usage(argc, argv, params);
 
-static void retrieval_params_print_usage(int argc, char ** argv, gpt_params & gpt_params, retrieval_params & params) {
-    gpt_params_print_usage(argc, argv, gpt_params);
-    printf("retrieval options:\n");
-    printf("  --context-file FNAME      file containing context to embed.\n");
-    printf("                            specify multiple files by providing --context-file option multiple times.\n");
-    printf("  --chunk-size N            minimum length of embedded text chunk (default:%d)\n", params.chunk_size);
-    printf("  --chunk-separator STRING\n");
-    printf("                            string to separate chunks (default: \"\\n\")\n");
-    printf("\n");
-}
-
-static void retrieval_params_parse(int argc, char ** argv, gpt_params & gpt_params, retrieval_params & retrieval_params) {
-    int i = 1;
-    std::string arg;
-    while (i < argc) {
-        arg = argv[i];
-        bool invalid_gpt_param = false;
-        if (gpt_params_find_arg(argc, argv, argv[i], gpt_params, i, invalid_gpt_param)) {
-            if (invalid_gpt_param) {
-                fprintf(stderr, "error: invalid argument: %s\n", arg.c_str());
-                retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
-                exit(1);
-            }
-            // option was parsed by gpt_params_find_arg
-        } else if (arg == "--context-file") {
-            if (++i >= argc) {
-                fprintf(stderr, "error: missing argument for --context-file\n");
-                retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
-                exit(1);
-            }
-            std::ifstream file(argv[i]);
-            if (!file) {
-                fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
-                retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
-                exit(1);
-            }
-            // store the external file name in params
-            retrieval_params.context_files.push_back(argv[i]);
-        } else if (arg == "--chunk-size") {
-            if (++i >= argc) {
-                fprintf(stderr, "error: missing argument for --chunk-size\n");
-                retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
-                exit(1);
-            }
-            retrieval_params.chunk_size = std::stoi(argv[i]);
-        } else if (arg == "--chunk-separator") {
-            if (++i >= argc) {
-                fprintf(stderr, "error: missing argument for --chunk-separator\n");
-                retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
-                exit(1);
-            }
-            retrieval_params.chunk_separator = argv[i];
-        } else {
-            // unknown argument
-            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
-            retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
-            exit(1);
-        }
-        i++;
-    }
+    LOG_TEE("\nexample usage:\n");
+    LOG_TEE("\n    %s --model ./models/bge-base-en-v1.5-f16.gguf --top-k 3 --context-file README.md --context-file License --chunk-size 100 --chunk-separator .\n", argv[0]);
+    LOG_TEE("\n");
 }
 
 struct chunk {
@@ -171,33 +111,35 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
 
 int main(int argc, char ** argv) {
     gpt_params params;
-    retrieval_params retrieval_params;
 
-    retrieval_params_parse(argc, argv, params, retrieval_params);
+    if (!gpt_params_parse(argc, argv, params)) {
+        print_usage(argc, argv, params);
+        return 1;
+    }
 
     // For BERT models, batch size must be equal to ubatch size
     params.n_ubatch = params.n_batch;
+    params.embedding = true;
 
-    if (retrieval_params.chunk_size <= 0) {
+    if (params.chunk_size <= 0) {
         fprintf(stderr, "chunk_size must be positive\n");
         return 1;
     }
 
-    if (retrieval_params.context_files.empty()) {
+    if (params.context_files.empty()) {
         fprintf(stderr, "context_files must be specified\n");
         return 1;
    }
 
-    params.embedding = true;
-
     print_build_info();
 
     printf("processing files:\n");
-    for (auto & context_file : retrieval_params.context_files) {
+    for (auto & context_file : params.context_files) {
         printf("%s\n", context_file.c_str());
     }
 
     std::vector<chunk> chunks;
-    for (auto & context_file : retrieval_params.context_files) {
-        std::vector<chunk> file_chunk = chunk_file(context_file, retrieval_params.chunk_size, retrieval_params.chunk_separator);
+    for (auto & context_file : params.context_files) {
+        std::vector<chunk> file_chunk = chunk_file(context_file, params.chunk_size, params.chunk_separator);
         chunks.insert(chunks.end(), file_chunk.begin(), file_chunk.end());
     }
     printf("Number of chunks: %ld\n", chunks.size());
@@ -242,7 +184,7 @@ int main(int argc, char ** argv) {
             return 1;
         }
         // add eos if not present
-        if (inp.empty() || inp.back() != llama_token_eos(model)) {
+        if (llama_token_eos(model) >= 0 && (inp.empty() || inp.back() != llama_token_eos(model))) {
             inp.push_back(llama_token_eos(model));
         }
         chunk.tokens = inp;
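The last hunk above is a behavioral fix folded into the migration: as the new guard itself implies, llama_token_eos can return a negative id for a model that defines no EOS token, so the check skips the append rather than pushing an invalid token. The same pattern in isolation, where inp is just a locally built token buffer:

    // Append EOS only when the model defines one and the sequence does not
    // already end with it. Mirrors the guard added in the hunk above.
    #include "llama.h"

    #include <vector>

    static void append_eos(const llama_model * model, std::vector<llama_token> & inp) {
        const llama_token eos = llama_token_eos(model); // negative when no EOS token is defined
        if (eos >= 0 && (inp.empty() || inp.back() != eos)) {
            inp.push_back(eos);
        }
    }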

examples/server/server.cpp

@@ -1238,7 +1238,7 @@ struct server_context {
     }
 
     json get_formated_generation(const server_slot & slot) const {
-        const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
+        const auto eos_bias   = slot.sparams.logit_bias.find(llama_token_eos(model));
         const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
 
         std::vector<std::string> samplers_sequence;
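For context on the touched line: the server encodes "ignore EOS" as a negative-infinity logit bias on the EOS token, which is exactly what the ignore_eos expression above detects. The check in isolation, with std::unordered_map standing in for the slot's actual logit_bias container (an assumption made for this sketch):

    // Sketch of the ignore_eos test: a -inf bias on the EOS token means
    // "never emit EOS". std::unordered_map stands in for slot.sparams.logit_bias.
    #include <cmath>
    #include <unordered_map>

    static bool is_eos_ignored(const std::unordered_map<int, float> & logit_bias, int eos_token) {
        const auto it = logit_bias.find(eos_token);
        return it != logit_bias.end() && it->second < 0.0f && std::isinf(it->second);
    }

Setting logit_bias[eos_token] = -INFINITY would therefore make the check return true, while any finite negative bias merely discourages EOS without disabling it.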