llama_get_embeddings_mean_pooled

Matt Grosso 2024-04-17 17:47:25 -07:00
parent 3b8f1ec4b1
commit 2a24f71497
2 changed files with 51 additions and 0 deletions

llama.cpp

@@ -16609,6 +16609,28 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id
    return it->second.data();
}

void llama_get_embeddings_mean_pooled(struct llama_context * ctx, int32_t skip_tokens, int32_t batch_tokens, float * dest) {
    GGML_ASSERT(dest);
    GGML_ASSERT(batch_tokens > 0);
    GGML_ASSERT(skip_tokens >= 0);
    GGML_ASSERT(skip_tokens < batch_tokens);

    const float inv_tokens_to_pool = 1.0f / (batch_tokens - skip_tokens);
    GGML_ASSERT(inv_tokens_to_pool > 0.0f);
    GGML_ASSERT(inv_tokens_to_pool <= 1.0f);

    const float * all_token_embeddings = ctx->embd.data();
    const llama_model * mdl = llama_get_model(ctx);
    const int32_t n_embd = llama_n_embd(mdl); // length of each embedding

    // dest is an accumulator, so it must start zeroed
    for (int32_t j = 0; j < n_embd; j++) {
        dest[j] = 0.0f;
    }
    for (int32_t i = skip_tokens; i < batch_tokens; i++) {
        const float * token_embedding = all_token_embeddings + i * n_embd;
        for (int32_t j = 0; j < n_embd; j++) {
            dest[j] += token_embedding[j];
        }
    }
    for (int32_t j = 0; j < n_embd; j++) {
        dest[j] *= inv_tokens_to_pool;
    }
}
const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
    GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
    return model->vocab.id_to_token[token].text.c_str();
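
A hypothetical usage sketch (not part of the commit): pooling everything after an instruction prefix, assuming the context was created with embeddings enabled and a batch of n_tokens tokens was just decoded.

#include "llama.h"
#include <stdlib.h>

// Pool tokens [n_prefix, n_tokens): the prefix tokens still contribute
// to attention during decode, but are excluded from the mean.
static float * pool_after_prefix(struct llama_context * ctx, int32_t n_prefix, int32_t n_tokens) {
    const int32_t n_embd = llama_n_embd(llama_get_model(ctx));
    float * pooled = malloc(n_embd * sizeof(float)); // filled by the call below
    if (pooled) {
        llama_get_embeddings_mean_pooled(ctx, n_prefix, n_tokens, pooled);
    }
    return pooled; // caller frees; not normalized (see llama_embd_normalize)
}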

llama.h

@@ -773,6 +773,35 @@ extern "C" {
    // shape: [n_embd] (1-dimensional)
    LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);

    // Get the mean pooled embedding for a subset of the tokens from the encoding.
    //
    // The result is not normalized; see llama_embd_normalize for that.
    //
    // The mean pooling here is done on the host rather than on the device, so
    // it works for model types that currently refuse to build a pooling layer
    // in the device graph. Currently two large llama embedding models, GritLM
    // and e5-mistral, are supported; notably both are initialized via
    // build_llama, which does not build a pooling layer (inp_mean). Both
    // models rely on instruction prompts whose tokens contribute to the
    // attention but may or may not be included in the mean pooling, depending
    // on the application.
    //
    // TODO: 1. support inp_mean in llama models when mean pooling is specified,
    //          so the mean can be calculated on the device, and
    // TODO: 2. have the context own the destination pooled embedding memory,
    //          to be more consistent with other APIs, while still allowing the
    //          application control over skipping some tokens.
    //
    // skip_tokens: the number of tokens to skip from the beginning of the batch
    // batch_tokens: the number of tokens in the batch
    // dest: the destination array for the mean pooled embedding
    //
    // 'dest' must point to an array of n_embd floats (one embedding's length)
    // 'batch_tokens' - 'skip_tokens' is the number of tokens to pool
    // [skip_tokens, batch_tokens) is the range of tokens to pool
    //
    LLAMA_API void llama_get_embeddings_mean_pooled(struct llama_context * ctx, int32_t skip_tokens, int32_t batch_tokens, float * dest);
    //
    // Vocab
    //
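
The header defers normalization to llama_embd_normalize (a helper that lives in the examples' common library, not in llama.h). A minimal sketch of the L2 normalization step itself, written out directly rather than through that helper:

#include <math.h>
#include <stdint.h>

// L2-normalize the pooled embedding `v` of length `n` in place, so that
// downstream cosine similarity reduces to a plain dot product.
static void embd_l2_normalize(float * v, int32_t n) {
    double sum = 0.0;
    for (int32_t i = 0; i < n; i++) {
        sum += (double) v[i] * v[i];
    }
    const double norm = sqrt(sum);
    if (norm > 0.0) {
        for (int32_t i = 0; i < n; i++) {
            v[i] = (float) (v[i] / norm);
        }
    }
}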