diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 817d98f5c..7cad37401 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -15,8 +15,6 @@ struct chunk { std::vector tokens; // embedding std::vector embedding; - // cosin similarity - float similarity; }; // chunk file data to chunks of size >= chunk_size @@ -233,6 +231,8 @@ int main(int argc, char ** argv) { // save embeddings to chunks for (int i = 0; i < n_chunks; i++) { chunks[i].embedding = std::vector(emb + i * n_embd, emb + (i + 1) * n_embd); + // clear tokens as they are no longer needed + chunks[i].tokens.clear(); } // start loop, receive query and return top k similar chunks based on cosine similarity @@ -250,21 +250,26 @@ int main(int argc, char ** argv) { delete[] query_emb; llama_batch_clear(query_batch); + // similarities tuple vector (chunk index, similarity) + std::vector> similarities; for (int i = 0; i < n_chunks; i++) { - float similarity = llama_embd_similarity_cos(chunks[i].embedding.data(), query_embedding.data(), n_embd); - chunks[i].similarity = similarity; + float sim = llama_embd_similarity_cos(chunks[i].embedding.data(), query_embedding.data(), n_embd); + similarities.push_back(std::make_pair(i, sim)); } - std::sort(chunks.begin(), chunks.end(), [](chunk & a, chunk & b) { - return a.similarity > b.similarity; + // sort similarities + std::sort(similarities.begin(), similarities.end(), [](const std::pair & a, const std::pair & b) { + return a.second > b.second; }); + printf("Top %d similar chunks:\n", params.sparams.top_k); for (int i = 0; i < std::min(params.sparams.top_k, (int) chunks.size()); i++) { - printf("filename: %s\n", chunks[i].filename.c_str()); - printf("filepos: %lld\n", (long long int) chunks[i].filepos); - printf("similarity: %f\n", chunks[i].similarity); - printf("textdata:\n%s\n", chunks[i].textdata.c_str()); + printf("filename: %s\n", chunks[similarities[i].first].filename.c_str()); + printf("filepos: %lld\n", (long long int) chunks[similarities[i].first].filepos); + printf("similarity: %f\n", similarities[i].second); + printf("textdata:\n%s\n", chunks[similarities[i].first].textdata.c_str()); printf("--------------------\n"); } + similarities.clear(); } // clean up