retrieval : minor
This commit is contained in:
parent
eb5fd2d8c5
commit
f6be52d793
3 changed files with 40 additions and 31 deletions
|
@ -1,25 +1,31 @@
|
||||||
# llama.cpp/examples/retrieval
|
# llama.cpp/examples/retrieval
|
||||||
|
|
||||||
Demonstration of simple retrieval technique based on cosin similarity
|
Demonstration of simple retrieval technique based on cosine similarity
|
||||||
|
|
||||||
More info:
|
More info:
|
||||||
https://github.com/ggerganov/llama.cpp/pull/6193
|
https://github.com/ggerganov/llama.cpp/pull/6193
|
||||||
|
|
||||||
### How to use
|
### How to use
|
||||||
|
|
||||||
`retieval.cpp` has parameters of its own:
|
`retieval.cpp` has parameters of its own:
|
||||||
- `--context-file`: file to be embedded - state this option multiple times to embed multiple files
|
- `--context-file`: file to be embedded - state this option multiple times to embed multiple files
|
||||||
- `--chunk-size`: minimum size of each text chunk to be embedded
|
- `--chunk-size`: minimum size of each text chunk to be embedded
|
||||||
- `--chunk-separator`: STRING to divide chunks by. newline by default
|
- `--chunk-separator`: STRING to divide chunks by. newline by default
|
||||||
|
|
||||||
`retrieval` example can be tested as below
|
`retrieval` example can be tested as follows:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
make -j && ./retrieval --model ./models/bge-base-en-v1.5-f16.gguf --top-k 3 --context-file README.md --context-file License --chunk-size 100 --chunk-separator .
|
make -j && ./retrieval --model ./models/bge-base-en-v1.5-f16.gguf --top-k 3 --context-file README.md --context-file License --chunk-size 100 --chunk-separator .
|
||||||
```
|
```
|
||||||
which chunks & embeds all given files and starts a loop requesting query inputs:
|
|
||||||
|
This chunks and embeds all given files and starts a loop requesting query inputs:
|
||||||
|
|
||||||
```
|
```
|
||||||
Enter query:
|
Enter query:
|
||||||
```
|
```
|
||||||
and on query input, top k chunks are shown along with file name, chunk position within file and original text
|
|
||||||
|
On each query input, top k chunks are shown along with file name, chunk position within file and original text:
|
||||||
|
|
||||||
```
|
```
|
||||||
Enter query: describe the mit license
|
Enter query: describe the mit license
|
||||||
batch_decode: n_tokens = 6, n_seq = 1
|
batch_decode: n_tokens = 6, n_seq = 1
|
||||||
|
|
|
@ -5,15 +5,15 @@
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
|
||||||
struct retrieval_params {
|
struct retrieval_params {
|
||||||
std::vector<std::string> context_files;// context files to embed
|
std::vector<std::string> context_files; // context files to embed
|
||||||
int32_t chunk_size = 64; // chunk size for context embedding
|
int32_t chunk_size = 64; // chunk size for context embedding
|
||||||
std::string chunk_separator = "\n"; // chunk separator for context embedding
|
std::string chunk_separator = "\n"; // chunk separator for context embedding
|
||||||
};
|
};
|
||||||
|
|
||||||
static void retrieval_params_print_usage(int argc, char** argv, gpt_params & gpt_params, retrieval_params & params) {
|
static void retrieval_params_print_usage(int argc, char ** argv, gpt_params & gpt_params, retrieval_params & params) {
|
||||||
gpt_print_usage(argc, argv, gpt_params);
|
gpt_print_usage(argc, argv, gpt_params);
|
||||||
printf("retrieval options:\n");
|
printf("retrieval options:\n");
|
||||||
printf(" --context-file FNAME files containing context to embed.\n");
|
printf(" --context-file FNAME file containing context to embed.\n");
|
||||||
printf(" specify multiple files by providing --context-file option multiple times.\n");
|
printf(" specify multiple files by providing --context-file option multiple times.\n");
|
||||||
printf(" --chunk-size N minimum length of embedded text chunk (default:%d)\n", params.chunk_size);
|
printf(" --chunk-size N minimum length of embedded text chunk (default:%d)\n", params.chunk_size);
|
||||||
printf(" --chunk-separator STRING\n");
|
printf(" --chunk-separator STRING\n");
|
||||||
|
@ -24,7 +24,7 @@ static void retrieval_params_print_usage(int argc, char** argv, gpt_params & gpt
|
||||||
static void retrieval_params_parse(int argc, char ** argv, gpt_params & gpt_params, retrieval_params & retrieval_params) {
|
static void retrieval_params_parse(int argc, char ** argv, gpt_params & gpt_params, retrieval_params & retrieval_params) {
|
||||||
int i = 1;
|
int i = 1;
|
||||||
std::string arg;
|
std::string arg;
|
||||||
while(i < argc) {
|
while (i < argc) {
|
||||||
arg = argv[i];
|
arg = argv[i];
|
||||||
bool invalid_gpt_param = false;
|
bool invalid_gpt_param = false;
|
||||||
if(gpt_params_find_arg(argc, argv, argv[i], gpt_params, i, invalid_gpt_param)) {
|
if(gpt_params_find_arg(argc, argv, argv[i], gpt_params, i, invalid_gpt_param)) {
|
||||||
|
@ -36,7 +36,7 @@ static void retrieval_params_parse(int argc, char ** argv, gpt_params & gpt_para
|
||||||
// option was parsed by gpt_params_find_arg
|
// option was parsed by gpt_params_find_arg
|
||||||
} else if (arg == "--context-file") {
|
} else if (arg == "--context-file") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
fprintf(stderr, "error: missing argument for --context-files\n");
|
fprintf(stderr, "error: missing argument for --context-file\n");
|
||||||
retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
|
retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
@ -87,7 +87,7 @@ struct chunk {
|
||||||
|
|
||||||
// chunk file data to chunks of size >= chunk_size
|
// chunk file data to chunks of size >= chunk_size
|
||||||
// chunk_separator is the separator between chunks
|
// chunk_separator is the separator between chunks
|
||||||
static std::vector<chunk> chunk_file(const std::string filename, int chunk_size, std::string chunk_separator) {
|
static std::vector<chunk> chunk_file(const std::string & filename, int chunk_size, const std::string & chunk_separator) {
|
||||||
std::vector<chunk> chunks;
|
std::vector<chunk> chunks;
|
||||||
std::ifstream f(filename.c_str());
|
std::ifstream f(filename.c_str());
|
||||||
|
|
||||||
|
@ -312,17 +312,20 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
|
struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
|
||||||
batch_add_seq(query_batch, query_tokens, 0);
|
batch_add_seq(query_batch, query_tokens, 0);
|
||||||
|
|
||||||
std::vector<float> query_emb(n_embd, 0);
|
std::vector<float> query_emb(n_embd, 0);
|
||||||
batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
|
batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
|
||||||
std::vector<float> query_embedding(query_emb.data(), query_emb.data() + n_embd);
|
|
||||||
llama_batch_clear(query_batch);
|
llama_batch_clear(query_batch);
|
||||||
|
|
||||||
// similarities tuple vector (chunk index, similarity)
|
// compute cosine similarities
|
||||||
|
{
|
||||||
std::vector<std::pair<int, float>> similarities;
|
std::vector<std::pair<int, float>> similarities;
|
||||||
for (int i = 0; i < n_chunks; i++) {
|
for (int i = 0; i < n_chunks; i++) {
|
||||||
float sim = llama_embd_similarity_cos(chunks[i].embedding.data(), query_embedding.data(), n_embd);
|
float sim = llama_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd);
|
||||||
similarities.push_back(std::make_pair(i, sim));
|
similarities.push_back(std::make_pair(i, sim));
|
||||||
}
|
}
|
||||||
|
|
||||||
// sort similarities
|
// sort similarities
|
||||||
std::sort(similarities.begin(), similarities.end(), [](const std::pair<int, float> & a, const std::pair<int, float> & b) {
|
std::sort(similarities.begin(), similarities.end(), [](const std::pair<int, float> & a, const std::pair<int, float> & b) {
|
||||||
return a.second > b.second;
|
return a.second > b.second;
|
||||||
|
@ -336,7 +339,7 @@ int main(int argc, char ** argv) {
|
||||||
printf("textdata:\n%s\n", chunks[similarities[i].first].textdata.c_str());
|
printf("textdata:\n%s\n", chunks[similarities[i].first].textdata.c_str());
|
||||||
printf("--------------------\n");
|
printf("--------------------\n");
|
||||||
}
|
}
|
||||||
similarities.clear();
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// clean up
|
// clean up
|
||||||
|
|
BIN
retrieval
BIN
retrieval
Binary file not shown.
Loading…
Add table
Add a link
Reference in a new issue