gritlm example using llama_get_embeddings_mean_pooled
This commit is contained in:
parent
2a24f71497
commit
798c29d6b9
1 changed files with 4 additions and 17 deletions
|
@ -54,24 +54,11 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
|||
|
||||
// allocate embedding output
|
||||
std::vector<float> emb_unorm(n_embd, 0.0f);
|
||||
float * raw_embed = emb_unorm.data();
|
||||
|
||||
// sum up all token embeddings
|
||||
for (int32_t k = n_inst; k < n_toks; k++) {
|
||||
float * emb = llama_get_embeddings_ith(ctx, k);
|
||||
for (uint64_t j = 0; j < n_embd; j++) {
|
||||
emb_unorm[j] += emb[j];
|
||||
}
|
||||
}
|
||||
|
||||
// divide by number of tokens (mean pooling)
|
||||
{
|
||||
const uint64_t n_sent = n_toks - n_inst;
|
||||
|
||||
for (uint64_t j = 0; j < n_embd; j++) {
|
||||
emb_unorm[j] /= n_sent;
|
||||
}
|
||||
}
|
||||
|
||||
// retrieve summed up token embeddings, skipping instruction tokens,
|
||||
// and writes the result to raw_embed
|
||||
llama_get_embeddings_mean_pooled(ctx, n_inst, n_toks, raw_embed);
|
||||
std::vector<float> emb_norm(emb_unorm.size());
|
||||
llama_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
|
||||
result.push_back(emb_norm);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue