gritlm example using llama_get_embeddings_mean_pooled
parent 2a24f71497
commit 798c29d6b9
1 changed file with 4 additions and 17 deletions
@@ -54,24 +54,11 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
 
         // allocate embedding output
         std::vector<float> emb_unorm(n_embd, 0.0f);
+        float * raw_embed = emb_unorm.data();
 
-        // sum up all token embeddings
-        for (int32_t k = n_inst; k < n_toks; k++) {
-            float * emb = llama_get_embeddings_ith(ctx, k);
-            for (uint64_t j = 0; j < n_embd; j++) {
-                emb_unorm[j] += emb[j];
-            }
-        }
-
-        // divide by number of tokens (mean pooling)
-        {
-            const uint64_t n_sent = n_toks - n_inst;
-
-            for (uint64_t j = 0; j < n_embd; j++) {
-                emb_unorm[j] /= n_sent;
-            }
-        }
-
+        // retrieve summed up token embeddings, skipping instruction tokens,
+        // and writes the result to raw_embed
+        llama_get_embeddings_mean_pooled(ctx, n_inst, n_toks, raw_embed);
         std::vector<float> emb_norm(emb_unorm.size());
         llama_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
         result.push_back(emb_norm);
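For context, the pooling that the removed lines performed can be sketched as a standalone helper. This is reconstructed from the deleted code and only assumes that llama_get_embeddings_mean_pooled is meant to average the per-token embeddings over the range [n_inst, n_toks) in the same way; the helper name mean_pool_embeddings is illustrative, not a llama.h symbol.

#include "llama.h"

// Sketch only: reproduces the mean pooling that the removed loop performed.
// `dst` must point to n_embd floats initialized to 0.0f, matching the old
// emb_unorm(n_embd, 0.0f) allocation.
static void mean_pool_embeddings(llama_context * ctx, int32_t n_inst, int32_t n_toks, uint64_t n_embd, float * dst) {
    // sum the embeddings of the sentence tokens, skipping the instruction tokens
    for (int32_t k = n_inst; k < n_toks; k++) {
        const float * emb = llama_get_embeddings_ith(ctx, k);
        for (uint64_t j = 0; j < n_embd; j++) {
            dst[j] += emb[j];
        }
    }

    // divide by the number of pooled tokens (mean pooling)
    const uint64_t n_sent = n_toks - n_inst;
    for (uint64_t j = 0; j < n_embd; j++) {
        dst[j] /= n_sent;
    }
}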