From c0c95edc89ae15c5375b6d36153513ce71cf7793 Mon Sep 17 00:00:00 2001
From: Matt Grosso
Date: Thu, 18 Apr 2024 17:00:46 -0700
Subject: [PATCH] rebase, handle sometimes smaller embd & new type

0. Initialize dest with zeros.
1. embd is not a vector anymore.
2. The embd size from embd_size may be smaller than batch_tokens because
   it honors the logits array, so we use ctx->n_outputs to bound the embd
   outer loop.
3. Remove the batch_tokens foot-gun parameter, since the context gives us
   authoritative information on the size of the embedding outputs.
4. Improve the comment docs.
5. Incorporate the new usage in the gritlm example.
---
 examples/gritlm/gritlm.cpp |  6 ++++--
 llama.cpp                  | 17 +++++++++------
 llama.h                    | 44 ++++++++++++++------------------------
 3 files changed, 31 insertions(+), 36 deletions(-)

diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp
index 007579f61..74375e3f0 100644
--- a/examples/gritlm/gritlm.cpp
+++ b/examples/gritlm/gritlm.cpp
@@ -57,8 +57,10 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         float * raw_embed = emb_unorm.data();
 
         // retrieve summed up token embeddings, skipping instruction tokens,
-        // and writes the result to raw_embed
-        llama_get_embeddings_mean_pooled(ctx, n_inst, n_toks, raw_embed);
+        // and writes the result to raw_embed. We pass 0 for the number of
+        // instruction tokens to skip, as we have already marked logits = false
+        // in the llama_batch_add loop for instruction tokens.
+        llama_get_embeddings_mean_pooled(ctx, 0, raw_embed);
         std::vector<float> emb_norm(emb_unorm.size());
         llama_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
         result.push_back(emb_norm);
diff --git a/llama.cpp b/llama.cpp
index f0cad60ee..e69630e44 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -16609,19 +16609,24 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id
     return it->second.data();
 }
 
-void llama_get_embeddings_mean_pooled(struct llama_context * ctx, int32_t skip_tokens, int32_t batch_tokens, float *dest) {
+void llama_get_embeddings_mean_pooled(struct llama_context * ctx, int32_t skip_tokens, float *dest) {
     GGML_ASSERT(dest);
-    GGML_ASSERT(batch_tokens > 0);
     GGML_ASSERT(skip_tokens >= 0);
+
+    int32_t n_embd       = ctx->model.hparams.n_embd;
+    int32_t batch_tokens = ctx->n_outputs;
     GGML_ASSERT(skip_tokens < batch_tokens);
+    GGML_ASSERT(ctx->embd_size >= (size_t)(batch_tokens * n_embd));
+
     float inv_tokens_to_pool = 1.0f / (batch_tokens - skip_tokens);
     GGML_ASSERT(inv_tokens_to_pool > 0.0f);
     GGML_ASSERT(inv_tokens_to_pool <= 1.0f);
-    float * all_token_embedddings = ctx->embd.data();
-    const llama_model * mdl = llama_get_model(ctx);
-    int32_t n_embd = llama_n_embd(mdl); // length of each embedding
+
+    for (int32_t i = 0; i < n_embd; i++) {
+        dest[i] = 0.0f;
+    }
     for (int32_t i = skip_tokens; i < batch_tokens; i++) {
-        float * token_embedding = all_token_embedddings + i * n_embd;
+        float * token_embedding = ctx->embd + i * n_embd;
         for (int32_t j = 0; j < n_embd; j++) {
             dest[j] += token_embedding[j];
         }
diff --git a/llama.h b/llama.h
index 2f2e31206..ab65ea6aa 100644
--- a/llama.h
+++ b/llama.h
@@ -773,34 +773,22 @@ extern "C" {
     // shape: [n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
-    // Get the mean pooled embeddings for a subset of the tokens from the encoding.
-    //
-    // They will not be normalized. see llama_embd_normalize for that.
-    //
-    // The mean pooling here is done outside of the device and thus will work
-    // for model types that currently refuse to build a pooling layer in the
-    // device. Currently two large llama embedding models, GritLM and
-    // e5-mistral are supported; notably both of those are initialized via
-    // build_llama which won't have a pooling layer, inp_mean. Both models rely
-    // on prompts that are tokenized and which contribute to the attention but
-    // which may or may not be included in the mean pooling, depending on the
-    // application.
-    //
-    // TODO: 1. support inp_mean in llama models when mean pooling is specified
-    // so we can have the man calculated on the device and
-    // TODO: 2. also have the context own the destination pooled embedding
-    // memory to be more consistent with other apis, but also continue to
-    // allow application control over skipping some tokens.
-    //
-    // skip_tokens: The number of tokens to skip from the beginning of the batch tokens
-    // batch_tokens: The number of tokens in the batch
-    // dest: The destination array to store the mean pooled embeddings
-    //
-    // 'dest' array pointer must have the same length as the embeddings
-    // 'batch_tokens' - 'skip_tokens' is the number of tokens to pool
-    // [skip_tokens, batch_tokens) is the range of tokens to pool
-    //
-    LLAMA_API void llama_get_embeddings_mean_pooled(struct llama_context * ctx, int32_t skip_tokens, int32_t batch_tokens, float *dest);
+    /// @details Get the mean pooled embeddings for a subset of the tokens from the encoding.
+    /// @param ctx Pointer to the llama_context.
+    /// @param skip_tokens The number of tokens to skip from the beginning of the per-token embeddings.
+    /// @param dest Receives the mean pooled embedding; it must point to at least sizeof(float) * llama_n_embd(model) bytes.
+    ///
+    /// If you used llama_batch_get_one and have instruction tokens to exclude from the
+    /// pooled embedding, or used llama_batch_add but wish to skip some tokens at the
+    /// beginning even though they have the 'logits' boolean set to true, then
+    /// set skip_tokens to a non-zero value.
+    ///
+    /// The result is not normalized; pass dest as the src of llama_embd_normalize for that.
+    ///
+    /// The mean pooling is done on the host (outside the device) and thus works
+    /// for model types that currently do not build the inp_mean pooling layer on the
+    /// device. See examples/gritlm/gritlm.cpp for an example of how to use this.
+    LLAMA_API void llama_get_embeddings_mean_pooled(struct llama_context * ctx, int32_t skip_tokens, float *dest);
 
     //
     // Vocab
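
Usage sketch (for reviewers, not part of the diff): a minimal example of how a caller might drive the revised API. It assumes the llama_batch_add and llama_embd_normalize helpers from common.h that the gritlm example already uses, and a context created with embeddings enabled; the function name embed_mean_pooled and the n_skip parameter are illustrative only.

    #include <vector>

    #include "llama.h"
    #include "common.h" // llama_batch_add / llama_embd_normalize helpers used by the gritlm example

    // Mean-pool the per-token embeddings of `tokens`, skipping the first `n_skip`
    // of them, and return an L2-normalized vector of length llama_n_embd(model).
    // Sketch only; error handling is omitted.
    static std::vector<float> embed_mean_pooled(llama_context * ctx,
                                                const std::vector<llama_token> & tokens,
                                                int32_t n_skip) {
        const llama_model * mdl    = llama_get_model(ctx);
        const int32_t       n_embd = llama_n_embd(mdl);

        llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1);
        for (size_t i = 0; i < tokens.size(); i++) {
            // logits = true so every token produces an output embedding;
            // the skipping is handled by llama_get_embeddings_mean_pooled below
            llama_batch_add(batch, tokens[i], (llama_pos) i, { 0 }, true);
        }

        llama_kv_cache_clear(ctx);
        llama_decode(ctx, batch);

        // sum and average the embeddings of output tokens [n_skip, n_outputs)
        std::vector<float> emb_unorm(n_embd, 0.0f);
        llama_get_embeddings_mean_pooled(ctx, n_skip, emb_unorm.data());

        // L2-normalize the pooled embedding, as the gritlm example does
        std::vector<float> emb_norm(n_embd);
        llama_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);

        llama_batch_free(batch);
        return emb_norm;
    }

In the gritlm example the instruction tokens are instead excluded by setting logits = false in the llama_batch_add loop, so it passes 0 for skip_tokens; a non-zero n_skip here illustrates the alternative path described in the header comment.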