llama : fix embeddings

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-02-29 15:39:10 +02:00
parent a0fc62661f
commit d0347840c1
No known key found for this signature in database
GPG key ID: BF970631944C16B7
6 changed files with 127 additions and 62 deletions

View file

@@ -163,7 +163,7 @@ extern "C" {
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
// - pos : the positions of the respective token in the sequence
// - seq_id : the sequence to which the respective token belongs
// - logits : if zero, the logits for the respective token will not be output
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
//
typedef struct llama_batch {
int32_t n_tokens;
@@ -173,7 +173,7 @@ extern "C" {
llama_pos * pos;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
int8_t * logits;
int8_t * logits; // TODO: rename this to "output"
// NOTE: helpers for smooth API transition - can be deprecated in the future
// for future-proof code, use the above fields instead and ignore everything below
@@ -260,7 +260,7 @@ extern "C" {
// Keep the booleans together to avoid misalignment during copy-by-value.
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embedding; // embedding mode only
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
// Abort callback
@@ -659,7 +659,7 @@ extern "C" {
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Get the embeddings for the ith sequence
// Get the embeddings for the ith token
// llama_get_embeddings(ctx) + i*n_embd
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);