llama : distinguish token vs sequence embeddings

ggml-ci
Georgi Gerganov 2024-03-04 19:14:22 +02:00
parent e66da356a4
commit 79e4eede23
5 changed files with 128 additions and 56 deletions


@@ -23,7 +23,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
     }
 }
 
-static void normalize(float * vec, float * out, int n) {
+static void normalize(const float * vec, float * out, int n) {
     float norm = 0;
     for (int i = 0; i < n; i++) {
         norm += vec[i] * vec[i];
@@ -50,9 +50,18 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
             continue;
         }
 
-        float * emb = llama_get_embeddings_ith(ctx, i);
+        // try to get sequence embeddings - supported only when pooling_type is not NONE
+        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+        if (embd == NULL) {
+            embd = llama_get_embeddings_ith(ctx, i);
+            if (embd == NULL) {
+                fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
+                continue;
+            }
+        }
 
         float * out = output + batch.seq_id[i][0] * n_embd;
-        normalize(emb, out, n_embd);
+        normalize(embd, out, n_embd);
     }
 }
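
Note: the lookup order above generalizes to any caller; a minimal caller-side sketch of the same fallback, assuming only what the diff shows (the helper name get_embd is hypothetical, not part of this commit):

#include "llama.h"

// prefer the pooled per-sequence embedding, fall back to the raw per-token row
static const float * get_embd(llama_context * ctx, const llama_batch & batch, int i) {
    const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
    if (embd == NULL) {
        embd = llama_get_embeddings_ith(ctx, i); // unpooled row of token i
    }
    return embd; // may be NULL if the context produced no embeddings
}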


@@ -13,7 +13,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(0)*32}
+        json= {"content": str(i)*1024}
     ) for i in range(n)])
 
     for response in responses:


@@ -1235,12 +1235,22 @@ struct llama_server_context
                     continue;
                 }
 
-                const float * data = llama_get_embeddings_ith(ctx, i);
-                std::vector<float> embedding(data, data + n_embd);
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                    if (embd == NULL) {
+                        LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
+                        res.result_json = json
+                        {
+                            {"embedding", std::vector<float>(n_embd, 0.0f)},
+                        };
+                        continue;
+                    }
+                }
 
                 res.result_json = json
                 {
-                    {"embedding", embedding },
+                    {"embedding", std::vector<float>(embd, embd + n_embd)},
                 };
             }
         }

llama.cpp

@@ -1983,7 +1983,12 @@ struct llama_context {
     bool logits_all = false;
 
     // embeddings output (2-dimensional array: [n_tokens][n_embd])
-    std::vector<float> embeddings;
+    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
+    std::vector<float> embd;
+
+    // sequence embeddings output (map of [n_embd] vectors)
+    // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
+    std::map<llama_seq_id, std::vector<float>> embd_seq;
 
     // memory buffers used to evaluate the model
     std::vector<uint8_t> buf_compute_meta;
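
Note: the two members are indexed differently; a small sketch of each layout, assuming the types shown in the struct above (not part of this commit):

#include <cstddef>
#include <map>
#include <vector>

// token mode (pooling_type == NONE): embd is a flat row-major buffer
static float token_component(const std::vector<float> & embd, int n_embd, int i, int j) {
    return embd[(std::size_t) i*n_embd + j]; // j-th component of token i's embedding
}

// sequence mode (pooling_type != NONE): embd_seq keeps one n_embd vector per sequence
static float seq_component(const std::map<int, std::vector<float>> & embd_seq, int seq_id, int j) {
    return embd_seq.at(seq_id)[j]; // j-th component of the pooled embedding
}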
@@ -6243,12 +6248,23 @@ struct llm_build_context {
         cur = inpL;
 
         // pooling layer
-        if (pooling_type == LLAMA_POOLING_TYPE_MEAN) {
-            cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
-        } else if (pooling_type == LLAMA_POOLING_TYPE_CLS) {
-            cur = ggml_get_rows(ctx0, cur, inp_cls);
-        } else {
-            GGML_ASSERT(pooling_type == LLAMA_POOLING_TYPE_NONE && "Invalid pooling type");
+        switch (pooling_type) {
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    // nop
+                } break;
+            case LLAMA_POOLING_TYPE_MEAN:
+                {
+                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
+                } break;
+            case LLAMA_POOLING_TYPE_CLS:
+                {
+                    cur = ggml_get_rows(ctx0, cur, inp_cls);
+                } break;
+            case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                {
+                    GGML_ASSERT(false && "Invalid pooling type");
+                } break;
         }
 
         cb(cur, "result_embd", -1);
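
Note: the MEAN case performs the per-sequence average as a matrix product with inp_mean, while the CLS case selects one row per sequence via ggml_get_rows. A naive scalar equivalent of the mean reduction, as a sketch in plain C++ (not ggml, not part of this commit):

#include <cstddef>
#include <map>
#include <vector>

// mean pooling spelled out: out[s] = average of the embeddings of all tokens in sequence s
static std::map<int, std::vector<float>> mean_pool(
        const std::vector<std::vector<float>> & token_embd, // [n_tokens][n_embd]
        const std::vector<int>                & seq_ids) {  // sequence id of each token
    std::map<int, std::vector<float>> out;
    std::map<int, int> count;
    for (std::size_t t = 0; t < token_embd.size(); t++) {
        auto & acc = out[seq_ids[t]];
        acc.resize(token_embd[t].size(), 0.0f);
        for (std::size_t j = 0; j < acc.size(); j++) {
            acc[j] += token_embd[t][j];
        }
        count[seq_ids[t]]++;
    }
    for (auto & kv : out) {
        for (float & v : kv.second) {
            v /= count[kv.first];
        }
    }
    return out;
}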
@@ -8259,17 +8275,23 @@ static int llama_decode_internal(
     struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
     struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
 
-    if (strcmp(res->name, "result_output") == 0) {
-        // the embeddings could be the second to last tensor, or the third to last tensor
-        if (strcmp(embd->name, "result_norm") != 0) {
-            embd = gf->nodes[gf->n_nodes - 3];
-            GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
+    if (!hparams.causal_attn) {
+        res = nullptr; // do not extract logits for embedding models such as BERT
+
+        // token or sequence embeddings
+        embd = gf->nodes[gf->n_nodes - 1];
+
+        GGML_ASSERT(strcmp(embd->name, "result_embd") == 0);
+    } else {
+        if (strcmp(res->name, "result_output") == 0) {
+            // the token embeddings could be the second to last tensor, or the third to last tensor
+            if (strcmp(embd->name, "result_norm") != 0) {
+                embd = gf->nodes[gf->n_nodes - 3];
+                GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
+            }
+        } else {
+            GGML_ASSERT(false && "missing result_output tensor");
         }
-    } else if (strcmp(res->name, "result_embd") == 0) {
-        embd = res;
-        res = nullptr;
-    } else {
-        GGML_ASSERT(false);
     }
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
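
Note: the branch above identifies graph outputs purely by tensor name; an isolated sketch of that probe, using only the gf->nodes/n_nodes fields the diff already relies on (the helper is hypothetical, not part of this commit):

#include <cstring>
#include "ggml.h"

// scan the tail of the graph for the embeddings tensor, as the code above does
static struct ggml_tensor * find_embd(struct ggml_cgraph * gf) {
    for (int k = 1; k <= 3 && k <= gf->n_nodes; k++) {
        struct ggml_tensor * t = gf->nodes[gf->n_nodes - k];
        if (strcmp(t->name, "result_embd") == 0 || strcmp(t->name, "result_norm") == 0) {
            return t;
        }
    }
    return nullptr;
}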
@@ -8368,30 +8390,46 @@ static int llama_decode_internal(
         // extract embeddings
         if (cparams.embeddings && embd) {
-            auto & embeddings_out = lctx.embeddings;
-
             ggml_backend_t backend_embd = ggml_backend_sched_get_node_backend(lctx.sched, embd);
             GGML_ASSERT(backend_embd != nullptr);
 
-            if (batch.logits) {
-                embeddings_out.resize(n_embd * n_tokens);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    if (batch.logits[i] == 0) {
-                        continue;
-                    }
-                    switch (hparams.pooling_type) {
-                        case LLAMA_POOLING_TYPE_CLS:
-                            ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
-                            break;
-                        case LLAMA_POOLING_TYPE_MEAN:
-                        case LLAMA_POOLING_TYPE_NONE:
-                            ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
-                            break;
-                        default:
-                            GGML_ASSERT(false && "unknown pooling type");
-                            break;
-                    }
-                }
-            }
+            switch (cparams.pooling_type) {
+                case LLAMA_POOLING_TYPE_NONE:
+                    {
+                        // extract token embeddings
+                        auto & embd_out = lctx.embd;
+
+                        if (batch.logits) {
+                            embd_out.resize(n_embd * n_tokens);
+                            for (uint32_t i = 0; i < n_tokens; i++) {
+                                if (batch.logits[i] == 0) {
+                                    continue;
+                                }
+
+                                ggml_backend_tensor_get_async(backend_embd, embd, embd_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
+                            }
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_CLS:
+                case LLAMA_POOLING_TYPE_MEAN:
+                    {
+                        // extract sequence embeddings
+                        auto & embd_seq_out = lctx.embd_seq;
+                        embd_seq_out.clear();
+
+                        for (uint32_t i = 0; i < n_tokens; i++) {
+                            const llama_seq_id seq_id = batch.seq_id[i][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(n_embd);
+                            ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                    {
+                        GGML_ASSERT(false && "unknown pooling type");
+                    } break;
+            }
 
             ggml_backend_synchronize(backend_embd);
         }
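
Note: both branches copy rows of n_embd floats and differ only in which index selects the row; the offset arithmetic, as a one-line sketch (not part of this commit):

#include <cstddef>

// NONE indexes rows by token i; CLS/MEAN index rows by seq_id
static std::size_t embd_row_offset(int n_embd, int row) {
    return (std::size_t) n_embd * row * sizeof(float); // matches (n_embd*i) and (n_embd*seq_id) above
}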
@@ -12273,7 +12311,7 @@ struct llama_context * llama_new_context_with_model(
             ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
 
             if (params.embeddings) {
-                ctx->embeddings.reserve(hparams.n_embd*cparams.n_batch);
+                ctx->embd.reserve(hparams.n_embd*cparams.n_batch);
             }
 
             // graph inputs
@@ -12708,7 +12746,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     // assume worst case for logits although only currently set ones are serialized
     const size_t s_logits = ctx->logits.capacity() * sizeof(float);
     const size_t s_embedding_size = sizeof(size_t);
-    const size_t s_embedding = ctx->embeddings.capacity() * sizeof(float);
+    const size_t s_embedding = ctx->embd.capacity() * sizeof(float);
     const size_t s_kv_buf_size = sizeof(size_t);
     const size_t s_kv_head = sizeof(uint32_t);
     const size_t s_kv_size = sizeof(uint32_t);
@@ -12817,12 +12855,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
     // copy embeddings
     {
-        const size_t embeddings_size = ctx->embeddings.size();
+        const size_t embeddings_size = ctx->embd.size();
 
         data_ctx->write(&embeddings_size, sizeof(embeddings_size));
 
         if (embeddings_size) {
-            data_ctx->write(ctx->embeddings.data(), embeddings_size * sizeof(float));
+            data_ctx->write(ctx->embd.data(), embeddings_size * sizeof(float));
         }
     }
@@ -12930,12 +12968,12 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
 
-        GGML_ASSERT(ctx->embeddings.capacity() == embeddings_size);
+        GGML_ASSERT(ctx->embd.capacity() == embeddings_size);
 
         if (embeddings_size) {
-            ctx->embeddings.resize(embeddings_size);
-            memcpy(ctx->embeddings.data(), inp, embeddings_size * sizeof(float));
+            ctx->embd.resize(embeddings_size);
+            memcpy(ctx->embd.data(), inp, embeddings_size * sizeof(float));
             inp += embeddings_size * sizeof(float);
         }
     }
@@ -13186,11 +13224,20 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
 }
 
 float * llama_get_embeddings(struct llama_context * ctx) {
-    return ctx->embeddings.data();
+    return ctx->embd.data();
 }
 
 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
-    return ctx->embeddings.data() + i*ctx->model.hparams.n_embd;
+    return ctx->embd.data() + i*ctx->model.hparams.n_embd;
+}
+
+float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
+    auto it = ctx->embd_seq.find(seq_id);
+    if (it == ctx->embd_seq.end()) {
+        return nullptr;
+    }
+
+    return it->second.data();
 }
 
 const char * llama_token_get_text(const struct llama_model * model, llama_token token) {

llama.h

@@ -655,14 +655,20 @@ extern "C" {
     // llama_get_logits(ctx) + i*n_vocab
     LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
 
-    // Get the embeddings for the input
-    // shape: [n_embd] (1-dimensional)
+    // Get all output token embeddings
+    // shape: [n_tokens*n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
     // Get the embeddings for the ith token
    // llama_get_embeddings(ctx) + i*n_embd
+    // shape: [n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
 
+    // Get the embeddings for a sequence id
+    // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
+    // shape: [n_embd] (1-dimensional)
+    LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
+
     //
     // Vocab
     //
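
Note: a minimal consumer of the new API, as a sketch (assumes an embedding-enabled context whose batch was already decoded; error handling omitted, not part of this commit):

#include "llama.h"
#include <stdio.h>

static void print_embedding(struct llama_context * ctx, llama_seq_id seq_id, int n_embd) {
    // pooled path: non-NULL only when pooling_type != LLAMA_POOLING_TYPE_NONE
    const float * embd = llama_get_embeddings_seq(ctx, seq_id);
    if (embd == NULL) {
        // unpooled path: read the row of the first output token instead
        embd = llama_get_embeddings_ith(ctx, 0);
    }
    for (int j = 0; embd != NULL && j < n_embd; j++) {
        printf("%f ", embd[j]);
    }
    printf("\n");
}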