llama : fix embeddings (#5796)

* llama : fix embeddings

ggml-ci

* llama : do not use KV cache for non-causal models

ggml-ci

* embeddings : fix llama_batch_init arg

* llama : add pooling switch

* llama : distinguish token vs sequence embeddings

ggml-ci

* llama : assert pooling tensor

* llama : simplify causal mask condition

ggml-ci

* llama : assert input batch with pooling enabled

* readme : update API changes list
This commit is contained in:
Georgi Gerganov 2024-03-04 22:31:20 +02:00 committed by GitHub
parent e0843afe1b
commit 29ae62d2ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 359 additions and 134 deletions

18
llama.h
View file

@ -163,7 +163,7 @@ extern "C" {
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
// - pos : the positions of the respective token in the sequence
// - seq_id : the sequence to which the respective token belongs
// - logits : if zero, the logits for the respective token will not be output
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
//
typedef struct llama_batch {
int32_t n_tokens;
@ -173,7 +173,7 @@ extern "C" {
llama_pos * pos;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
int8_t * logits;
int8_t * logits; // TODO: rename this to "output"
// NOTE: helpers for smooth API transition - can be deprecated in the future
// for future-proof code, use the above fields instead and ignore everything below
@ -260,7 +260,7 @@ extern "C" {
// Keep the booleans together to avoid misalignment during copy-by-value.
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embedding; // embedding mode only
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
// Abort callback
@ -655,14 +655,20 @@ extern "C" {
// llama_get_logits(ctx) + i*n_vocab
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
// Get the embeddings for the input
// shape: [n_embd] (1-dimensional)
// Get all output token embeddings
// shape: [n_tokens*n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Get the embeddings for the ith sequence
// Get the embeddings for the ith token
// llama_get_embeddings(ctx) + i*n_embd
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
// Get the embeddings for a sequence id
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
//
// Vocab
//