embedding : print cosine similarity (#899)

This commit is contained in:
Georgi Gerganov 2024-03-14 10:12:29 +02:00
parent 19885d205e
commit 0fd6c1f015
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 36 additions and 25 deletions

View file

@@ -168,14 +168,25 @@ int main(int argc, char ** argv) {
batch_decode(ctx, batch, out, s, n_embd);
// print first 3 embeddings
fprintf(stdout, "\n");
for (int j = 0; j < std::min(3, n_prompts); j++) {
fprintf(stderr, "embedding %d: ", j);
for (int i = 0; i < n_embd; i++) {
fprintf(stderr, "%f ", emb[j * n_embd + i]);
fprintf(stdout, "embedding %d: ", j);
for (int i = 0; i < std::min(16, n_embd); i++) {
fprintf(stdout, "%f ", emb[j * n_embd + i]);
}
fprintf(stderr, "\n\n");
fprintf(stdout, "\n");
}
// print cosine similarity matrix
fprintf(stdout, "\n");
printf("cosine similarity matrix:\n\n");
for (int i = 0; i < n_prompts; i++) {
for (int j = 0; j < n_prompts; j++) {
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f ", sim);
}
fprintf(stdout, "\n");
}
fprintf(stderr, "\n");
// clean up
llama_print_timings(ctx);