diff --git a/examples/retrieval/README.md b/examples/retrieval/README.md index 7a31d9a8c..2b2595c46 100644 --- a/examples/retrieval/README.md +++ b/examples/retrieval/README.md @@ -1,25 +1,31 @@ # llama.cpp/examples/retrieval -Demonstration of simple retrieval technique based on cosin similarity +Demonstration of simple retrieval technique based on cosine similarity More info: https://github.com/ggerganov/llama.cpp/pull/6193 ### How to use + `retieval.cpp` has parameters of its own: - `--context-file`: file to be embedded - state this option multiple times to embed multiple files - `--chunk-size`: minimum size of each text chunk to be embedded - `--chunk-separator`: STRING to divide chunks by. newline by default -`retrieval` example can be tested as below +`retrieval` example can be tested as follows: + ```bash make -j && ./retrieval --model ./models/bge-base-en-v1.5-f16.gguf --top-k 3 --context-file README.md --context-file License --chunk-size 100 --chunk-separator . ``` -which chunks & embeds all given files and starts a loop requesting query inputs: + +This chunks and embeds all given files and starts a loop requesting query inputs: + ``` Enter query: ``` -and on query input, top k chunks are shown along with file name, chunk position within file and original text + +On each query input, top k chunks are shown along with file name, chunk position within file and original text: + ``` Enter query: describe the mit license batch_decode: n_tokens = 6, n_seq = 1 diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 335fd0700..5ba71e76a 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -5,15 +5,15 @@ #include struct retrieval_params { - std::vector context_files;// context files to embed - int32_t chunk_size = 64; // chunk size for context embedding - std::string chunk_separator = "\n"; // chunk separator for context embedding + std::vector context_files; // context files to embed + int32_t chunk_size = 64; // chunk size for context embedding + std::string chunk_separator = "\n"; // chunk separator for context embedding }; -static void retrieval_params_print_usage(int argc, char** argv, gpt_params & gpt_params, retrieval_params & params) { +static void retrieval_params_print_usage(int argc, char ** argv, gpt_params & gpt_params, retrieval_params & params) { gpt_print_usage(argc, argv, gpt_params); printf("retrieval options:\n"); - printf(" --context-file FNAME files containing context to embed.\n"); + printf(" --context-file FNAME file containing context to embed.\n"); printf(" specify multiple files by providing --context-file option multiple times.\n"); printf(" --chunk-size N minimum length of embedded text chunk (default:%d)\n", params.chunk_size); printf(" --chunk-separator STRING\n"); @@ -24,7 +24,7 @@ static void retrieval_params_print_usage(int argc, char** argv, gpt_params & gpt static void retrieval_params_parse(int argc, char ** argv, gpt_params & gpt_params, retrieval_params & retrieval_params) { int i = 1; std::string arg; - while(i < argc) { + while (i < argc) { arg = argv[i]; bool invalid_gpt_param = false; if(gpt_params_find_arg(argc, argv, argv[i], gpt_params, i, invalid_gpt_param)) { @@ -36,7 +36,7 @@ static void retrieval_params_parse(int argc, char ** argv, gpt_params & gpt_para // option was parsed by gpt_params_find_arg } else if (arg == "--context-file") { if (++i >= argc) { - fprintf(stderr, "error: missing argument for --context-files\n"); + fprintf(stderr, "error: missing argument for --context-file\n"); retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params); exit(1); } @@ -87,7 +87,7 @@ struct chunk { // chunk file data to chunks of size >= chunk_size // chunk_separator is the separator between chunks -static std::vector chunk_file(const std::string filename, int chunk_size, std::string chunk_separator) { +static std::vector chunk_file(const std::string & filename, int chunk_size, const std::string & chunk_separator) { std::vector chunks; std::ifstream f(filename.c_str()); @@ -312,31 +312,34 @@ int main(int argc, char ** argv) { struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1); batch_add_seq(query_batch, query_tokens, 0); + std::vector query_emb(n_embd, 0); batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); - std::vector query_embedding(query_emb.data(), query_emb.data() + n_embd); + llama_batch_clear(query_batch); - // similarities tuple vector (chunk index, similarity) - std::vector> similarities; - for (int i = 0; i < n_chunks; i++) { - float sim = llama_embd_similarity_cos(chunks[i].embedding.data(), query_embedding.data(), n_embd); - similarities.push_back(std::make_pair(i, sim)); - } - // sort similarities - std::sort(similarities.begin(), similarities.end(), [](const std::pair & a, const std::pair & b) { - return a.second > b.second; - }); + // compute cosine similarities + { + std::vector> similarities; + for (int i = 0; i < n_chunks; i++) { + float sim = llama_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd); + similarities.push_back(std::make_pair(i, sim)); + } - printf("Top %d similar chunks:\n", params.sparams.top_k); - for (int i = 0; i < std::min(params.sparams.top_k, (int) chunks.size()); i++) { - printf("filename: %s\n", chunks[similarities[i].first].filename.c_str()); - printf("filepos: %lld\n", (long long int) chunks[similarities[i].first].filepos); - printf("similarity: %f\n", similarities[i].second); - printf("textdata:\n%s\n", chunks[similarities[i].first].textdata.c_str()); - printf("--------------------\n"); + // sort similarities + std::sort(similarities.begin(), similarities.end(), [](const std::pair & a, const std::pair & b) { + return a.second > b.second; + }); + + printf("Top %d similar chunks:\n", params.sparams.top_k); + for (int i = 0; i < std::min(params.sparams.top_k, (int) chunks.size()); i++) { + printf("filename: %s\n", chunks[similarities[i].first].filename.c_str()); + printf("filepos: %lld\n", (long long int) chunks[similarities[i].first].filepos); + printf("similarity: %f\n", similarities[i].second); + printf("textdata:\n%s\n", chunks[similarities[i].first].textdata.c_str()); + printf("--------------------\n"); + } } - similarities.clear(); } // clean up diff --git a/retrieval b/retrieval deleted file mode 100755 index dd31789f8..000000000 Binary files a/retrieval and /dev/null differ