From 798c29d6b9fc700b62432ba65d1ed52dbcdd3164 Mon Sep 17 00:00:00 2001
From: Matt Grosso
Date: Wed, 17 Apr 2024 17:49:29 -0700
Subject: [PATCH] gritlm example using llama_get_embeddings_mean_pooled

---
 examples/gritlm/gritlm.cpp | 21 ++++-----------------
 1 file changed, 4 insertions(+), 17 deletions(-)

diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp
index 52fd719b3..007579f61 100644
--- a/examples/gritlm/gritlm.cpp
+++ b/examples/gritlm/gritlm.cpp
@@ -54,24 +54,11 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         // allocate embedding output
         std::vector<float> emb_unorm(n_embd, 0.0f);
+        float * raw_embed = emb_unorm.data();
 
-        // sum up all token embeddings
-        for (int32_t k = n_inst; k < n_toks; k++) {
-            float * emb = llama_get_embeddings_ith(ctx, k);
-            for (uint64_t j = 0; j < n_embd; j++) {
-                emb_unorm[j] += emb[j];
-            }
-        }
-
-        // divide by number of tokens (mean pooling)
-        {
-            const uint64_t n_sent = n_toks - n_inst;
-
-            for (uint64_t j = 0; j < n_embd; j++) {
-                emb_unorm[j] /= n_sent;
-            }
-        }
-
+        // retrieve the mean-pooled token embeddings, skipping the instruction
+        // tokens, and write the result to raw_embed
+        llama_get_embeddings_mean_pooled(ctx, n_inst, n_toks, raw_embed);
         std::vector<float> emb_norm(emb_unorm.size());
         llama_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
         result.push_back(emb_norm);
 
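
Note for reviewers: the patch above only shows the call site of the new
llama_get_embeddings_mean_pooled API. As a reading aid, here is a minimal
C++ sketch of the pooling the call is expected to perform, reconstructed
from the inline loop the patch deletes. The helper name and body are
illustrative assumptions, not the library implementation; the only calls
taken from the existing llama.h API are llama_get_model, llama_n_embd,
and llama_get_embeddings_ith.

    #include "llama.h"

    // Sketch only: mirrors the loop removed from encode(). Assumes out
    // points at n_embd writable floats and that i_end > i_start.
    static void get_embeddings_mean_pooled_sketch(
            llama_context * ctx, int32_t i_start, int32_t i_end, float * out) {
        const int32_t n_embd = llama_n_embd(llama_get_model(ctx));

        // start from a zeroed accumulator
        for (int32_t j = 0; j < n_embd; j++) {
            out[j] = 0.0f;
        }

        // sum the embedding of every token in [i_start, i_end),
        // i.e. skip the first i_start (instruction) tokens
        for (int32_t k = i_start; k < i_end; k++) {
            const float * emb = llama_get_embeddings_ith(ctx, k);
            for (int32_t j = 0; j < n_embd; j++) {
                out[j] += emb[j];
            }
        }

        // divide by the number of pooled tokens (mean pooling)
        const float n_pooled = (float) (i_end - i_start);
        for (int32_t j = 0; j < n_embd; j++) {
            out[j] /= n_pooled;
        }
    }

With that shape, the call site reads naturally: pool everything after the
n_inst instruction tokens, leaving the unnormalized mean in emb_unorm via
raw_embed, which llama_embd_normalize then turns into emb_norm.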