embedding : print cosine similarity (#899)
This commit is contained in:
parent
19885d205e
commit
0fd6c1f015
4 changed files with 36 additions and 25 deletions
|
@ -6,22 +6,6 @@
|
|||
|
||||
// #define GRIT_DEBUG
|
||||
|
||||
static float dot_product(const std::vector<float> & v1, const std::vector<float> & v2) {
|
||||
float dot = 0.0f;
|
||||
for (uint64_t i = 0; i < v1.size(); ++i) {
|
||||
dot += v1[i] * v2[i];
|
||||
}
|
||||
return dot;
|
||||
}
|
||||
|
||||
static float norm(const std::vector<float> & v) {
|
||||
return std::sqrt(dot_product(v, v));
|
||||
}
|
||||
|
||||
static float cosine_similarity(const std::vector<float> & v1, const std::vector<float> & v2) {
|
||||
return dot_product(v1, v2) / (norm(v1) * norm(v2));
|
||||
}
|
||||
|
||||
static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
|
||||
std::vector<std::vector<float>> result;
|
||||
|
||||
|
@ -203,10 +187,12 @@ int main(int argc, char * argv[]) {
|
|||
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
|
||||
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
|
||||
|
||||
const float cosine_sim_q0_d0 = cosine_similarity(q_rep[0], d_rep[0]);
|
||||
const float cosine_sim_q0_d1 = cosine_similarity(q_rep[0], d_rep[1]);
|
||||
const float cosine_sim_q1_d0 = cosine_similarity(q_rep[1], d_rep[0]);
|
||||
const float cosine_sim_q1_d1 = cosine_similarity(q_rep[1], d_rep[1]);
|
||||
const int n_embd = llama_n_embd(mdl);
|
||||
|
||||
const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
|
||||
const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
|
||||
const float cosine_sim_q1_d0 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[0].data(), n_embd);
|
||||
const float cosine_sim_q1_d1 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[1].data(), n_embd);
|
||||
|
||||
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0);
|
||||
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue