rebase, handle sometimes smaller embd & new type

0. initialize dest with zeros
1. embd is not a vector anymore
2. the embd size reported by embd_size may be smaller than batch_tokens
    because it honors the logits array, so we use ctx->n_outputs to bound
    the embd outer loop.
3. remove the batch_tokens foot-gun parameter, since the context already
   holds authoritative information on the size of the embedding outputs
   (a caller-side sketch of the resulting API follows below).
4. improve comment docs
5. incorporate the new usage into the gritlm example
Matt Grosso 2024-04-18 17:00:46 -07:00
parent 798c29d6b9
commit c0c95edc89
3 changed files with 31 additions and 36 deletions
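
For orientation, a minimal caller-side sketch of the API after this commit: decode a batch, then mean-pool whatever output rows the context produced and normalize the result. This is a sketch, not part of the patch; it assumes a context created with embeddings enabled and a batch already decoded, and that llama_embd_normalize is available from the repo's common.h as in the gritlm example. pool_last_batch is an illustrative name.

#include <vector>
#include "llama.h"
#include "common.h" // llama_embd_normalize, as used by the gritlm example

// Illustrative helper (not part of this patch): pool every output row of the
// most recently decoded batch into one normalized embedding.
static std::vector<float> pool_last_batch(llama_context * ctx) {
    const int32_t n_embd = llama_n_embd(llama_get_model(ctx));

    // destination must hold n_embd floats; skip_tokens = 0 pools all output rows
    std::vector<float> emb_unorm(n_embd, 0.0f);
    llama_get_embeddings_mean_pooled(ctx, 0, emb_unorm.data());

    // the pooled result is not normalized; normalize it the same way gritlm does
    std::vector<float> emb_norm(n_embd);
    llama_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
    return emb_norm;
}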


@@ -57,8 +57,10 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         float * raw_embed = emb_unorm.data();
         // retrieve summed up token embeddings, skipping instruction tokens,
-        // and writes the result to raw_embed
-        llama_get_embeddings_mean_pooled(ctx, n_inst, n_toks, raw_embed);
+        // and writes the result to raw_embed. We pass 0 for number of
+        // instructions to skip as we have already marked logits = false in the
+        // llama_batch_add loop for instruction tokens.
+        llama_get_embeddings_mean_pooled(ctx, 0, raw_embed);
         std::vector<float> emb_norm(emb_unorm.size());
         llama_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
         result.push_back(emb_norm);
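
The "0 instructions to skip" in the new comment relies on how the batch was built earlier in encode(): instruction tokens are added with logits = false, so they never produce output rows and drop out of the mean automatically. Roughly, that pattern looks like the sketch below, a paraphrase using the llama_batch_clear / llama_batch_add helpers from common.h rather than the exact gritlm.cpp code; decode_for_pooling, tokens, and n_inst are illustrative names.

#include <vector>
#include "llama.h"
#include "common.h" // llama_batch_clear / llama_batch_add helpers

// Illustrative sketch: the first n_inst entries of `tokens` are the instruction.
static void decode_for_pooling(llama_context * ctx, llama_batch & batch,
                               const std::vector<llama_token> & tokens, int32_t n_inst) {
    llama_batch_clear(batch);
    for (int32_t k = 0; k < (int32_t) tokens.size(); k++) {
        // logits = false for instruction tokens: they still attend, but they
        // produce no output row, so the mean pool skips them with skip_tokens = 0
        const bool want_output = k >= n_inst;
        llama_batch_add(batch, tokens[k], k, { 0 }, want_output);
    }
    llama_decode(ctx, batch); // error handling omitted
}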


@@ -16609,19 +16609,24 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id
     return it->second.data();
 }
-void llama_get_embeddings_mean_pooled(struct llama_context * ctx, int32_t skip_tokens, int32_t batch_tokens, float *dest) {
+void llama_get_embeddings_mean_pooled(struct llama_context * ctx, int32_t skip_tokens, float *dest) {
     GGML_ASSERT(dest);
-    GGML_ASSERT(batch_tokens > 0);
     GGML_ASSERT(skip_tokens >= 0);
+    int32_t n_embd = ctx->model.hparams.n_embd;
+    int32_t batch_tokens = ctx->n_outputs;
     GGML_ASSERT(skip_tokens < batch_tokens);
+    GGML_ASSERT(ctx->embd_size >= (size_t)(batch_tokens * n_embd));
     float inv_tokens_to_pool = 1.0f / (batch_tokens - skip_tokens);
     GGML_ASSERT(inv_tokens_to_pool > 0.0f);
     GGML_ASSERT(inv_tokens_to_pool <= 1.0f);
-    float * all_token_embedddings = ctx->embd.data();
-    const llama_model * mdl = llama_get_model(ctx);
-    int32_t n_embd = llama_n_embd(mdl); // length of each embedding
+    for (int32_t i = 0; i < n_embd; i++) {
+        dest[i] = 0.0f;
+    }
     for (int32_t i = skip_tokens; i < batch_tokens; i++) {
-        float * token_embedding = all_token_embedddings + i * n_embd;
+        float * token_embedding = ctx->embd + i * n_embd;
         for (int32_t j = 0; j < n_embd; j++) {
             dest[j] += token_embedding[j];
         }
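
The hunk shown here is truncated before the end of the function; presumably the remainder scales each component by inv_tokens_to_pool. As a standalone reference for the whole computation, here is a sketch of the technique over a flat row-major buffer, not the llama.cpp code itself; mean_pool_rows and its parameters are illustrative names.

#include <cassert>
#include <cstddef>
#include <cstdint>

// Reference sketch: mean-pool rows [skip, n_outputs) of a row-major buffer of
// n_outputs embeddings, each n_embd floats wide, into dest (n_embd floats).
static void mean_pool_rows(const float * embd, int32_t n_outputs, int32_t n_embd,
                           int32_t skip, float * dest) {
    assert(embd != nullptr && dest != nullptr);
    assert(skip >= 0 && skip < n_outputs);
    const float inv_tokens_to_pool = 1.0f / (float)(n_outputs - skip);
    for (int32_t j = 0; j < n_embd; j++) {
        dest[j] = 0.0f;                    // initialize dest with zeros
    }
    for (int32_t i = skip; i < n_outputs; i++) {
        const float * row = embd + (size_t) i * n_embd;
        for (int32_t j = 0; j < n_embd; j++) {
            dest[j] += row[j];             // sum the pooled rows
        }
    }
    for (int32_t j = 0; j < n_embd; j++) {
        dest[j] *= inv_tokens_to_pool;     // divide by the number of pooled rows
    }
}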

llama.h

@@ -773,34 +773,22 @@ extern "C" {
     // shape: [n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
-    // Get the mean pooled embeddings for a subset of the tokens from the encoding.
-    //
-    // They will not be normalized. see llama_embd_normalize for that.
-    //
-    // The mean pooling here is done outside of the device and thus will work
-    // for model types that currently refuse to build a pooling layer in the
-    // device. Currently two large llama embedding models, GritLM and
-    // e5-mistral are supported; notably both of those are initialized via
-    // build_llama which won't have a pooling layer, inp_mean. Both models rely
-    // on prompts that are tokenized and which contribute to the attention but
-    // which may or may not be included in the mean pooling, depending on the
-    // application.
-    //
-    // TODO: 1. support inp_mean in llama models when mean pooling is specified
-    //    so we can have the man calculated on the device and
-    // TODO: 2. also have the context own the destination pooled embedding
-    //    memory to be more consistent with other apis, but also continue to
-    //    allow application control over skipping some tokens.
-    //
-    // skip_tokens: The number of tokens to skip from the beginning of the batch tokens
-    // batch_tokens: The number of tokens in the batch
-    // dest: The destination array to store the mean pooled embeddings
-    //
-    // 'dest' array pointer must have the same length as the embeddings
-    // 'batch_tokens' - 'skip_tokens' is the number of tokens to pool
-    // [skip_tokens, batch_tokens) is the range of tokens to pool
-    //
-    LLAMA_API void llama_get_embeddings_mean_pooled(struct llama_context * ctx, int32_t skip_tokens, int32_t batch_tokens, float *dest);
+    /// @details Get the mean pooled embeddings for a subset of the tokens from the encoding.
+    /// @param ctx Pointer to the llama_context.
+    /// @param skip_tokens The number of tokens to skip from the beginning of the embeddings array of arrays.
+    /// @param dest Will store the mean pooled embeddings, so it must point to sizeof(float) * llama_n_embd(model) bytes.
+    ///
+    /// If you used llama_batch_get_one and have instructions to skip from the
+    /// embedding, or used llama_batch_add but wish to skip some tokens from the
+    /// beginning even though they have the 'logits' boolean set to true, then
+    /// set skip_tokens to a non-zero value.
+    ///
+    /// Results will not be normalized. Pass the dest of this as src to llama_embd_normalize for that.
+    ///
+    /// The mean pooling here is done outside of the device and thus will work
+    /// for model types that currently don't use the inp_mean pooling layer in the
+    /// device. See examples/gritlm/gritlm.cpp for an example of how to use this.
+    LLAMA_API void llama_get_embeddings_mean_pooled(struct llama_context * ctx, int32_t skip_tokens, float *dest);
     //
     // Vocab
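
The non-zero skip_tokens case described in the new header comment can be sketched as follows. This is an illustrative fragment, not part of the patch: every token is added with logits = true so each one produces an output row, and the first n_inst rows are then excluded at pooling time. pool_skipping_instruction and the variable names are placeholders; llama_batch_clear / llama_batch_add come from common.h.

#include <vector>
#include "llama.h"
#include "common.h" // llama_batch_clear / llama_batch_add helpers

// Illustrative sketch: all tokens request outputs, so skip_tokens must be used
// to drop the instruction rows from the mean.
static void pool_skipping_instruction(llama_context * ctx, llama_batch & batch,
                                      const std::vector<llama_token> & tokens,
                                      int32_t n_inst, std::vector<float> & out) {
    llama_batch_clear(batch);
    for (int32_t k = 0; k < (int32_t) tokens.size(); k++) {
        llama_batch_add(batch, tokens[k], k, { 0 }, true); // logits = true for every token
    }
    llama_decode(ctx, batch); // error handling omitted

    out.assign(llama_n_embd(llama_get_model(ctx)), 0.0f);
    llama_get_embeddings_mean_pooled(ctx, /*skip_tokens=*/ n_inst, out.data());
}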