store similarities in separate vector

This commit is contained in:
Minsoo Cheong 2024-03-22 15:03:34 +09:00
parent d33a015749
commit 208e1f03c3

View file

@ -15,8 +15,6 @@ struct chunk {
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
// embedding // embedding
std::vector<float> embedding; std::vector<float> embedding;
// cosin similarity
float similarity;
}; };
// chunk file data to chunks of size >= chunk_size // chunk file data to chunks of size >= chunk_size
@ -233,6 +231,8 @@ int main(int argc, char ** argv) {
// save embeddings to chunks // save embeddings to chunks
for (int i = 0; i < n_chunks; i++) { for (int i = 0; i < n_chunks; i++) {
chunks[i].embedding = std::vector<float>(emb + i * n_embd, emb + (i + 1) * n_embd); chunks[i].embedding = std::vector<float>(emb + i * n_embd, emb + (i + 1) * n_embd);
// clear tokens as they are no longer needed
chunks[i].tokens.clear();
} }
// start loop, receive query and return top k similar chunks based on cosine similarity // start loop, receive query and return top k similar chunks based on cosine similarity
@ -250,21 +250,26 @@ int main(int argc, char ** argv) {
delete[] query_emb; delete[] query_emb;
llama_batch_clear(query_batch); llama_batch_clear(query_batch);
// similarities tuple vector (chunk index, similarity)
std::vector<std::pair<int, float>> similarities;
for (int i = 0; i < n_chunks; i++) { for (int i = 0; i < n_chunks; i++) {
float similarity = llama_embd_similarity_cos(chunks[i].embedding.data(), query_embedding.data(), n_embd); float sim = llama_embd_similarity_cos(chunks[i].embedding.data(), query_embedding.data(), n_embd);
chunks[i].similarity = similarity; similarities.push_back(std::make_pair(i, sim));
} }
std::sort(chunks.begin(), chunks.end(), [](chunk & a, chunk & b) { // sort similarities
return a.similarity > b.similarity; std::sort(similarities.begin(), similarities.end(), [](const std::pair<int, float> & a, const std::pair<int, float> & b) {
return a.second > b.second;
}); });
printf("Top %d similar chunks:\n", params.sparams.top_k); printf("Top %d similar chunks:\n", params.sparams.top_k);
for (int i = 0; i < std::min(params.sparams.top_k, (int) chunks.size()); i++) { for (int i = 0; i < std::min(params.sparams.top_k, (int) chunks.size()); i++) {
printf("filename: %s\n", chunks[i].filename.c_str()); printf("filename: %s\n", chunks[similarities[i].first].filename.c_str());
printf("filepos: %lld\n", (long long int) chunks[i].filepos); printf("filepos: %lld\n", (long long int) chunks[similarities[i].first].filepos);
printf("similarity: %f\n", chunks[i].similarity); printf("similarity: %f\n", similarities[i].second);
printf("textdata:\n%s\n", chunks[i].textdata.c_str()); printf("textdata:\n%s\n", chunks[similarities[i].first].textdata.c_str());
printf("--------------------\n"); printf("--------------------\n");
} }
similarities.clear();
} }
// clean up // clean up