llama : distinguish token vs sequence embeddings
ggml-ci
This commit is contained in:
parent
e66da356a4
commit
79e4eede23
5 changed files with 128 additions and 56 deletions
|
@ -23,7 +23,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void normalize(float * vec, float * out, int n) {
|
static void normalize(const float * vec, float * out, int n) {
|
||||||
float norm = 0;
|
float norm = 0;
|
||||||
for (int i = 0; i < n; i++) {
|
for (int i = 0; i < n; i++) {
|
||||||
norm += vec[i] * vec[i];
|
norm += vec[i] * vec[i];
|
||||||
|
@ -50,9 +50,18 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
float * emb = llama_get_embeddings_ith(ctx, i);
|
// try to get sequence embeddings - supported only when pooling_type is not NONE
|
||||||
|
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||||
|
if (embd == NULL) {
|
||||||
|
embd = llama_get_embeddings_ith(ctx, i);
|
||||||
|
if (embd == NULL) {
|
||||||
|
fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
float * out = output + batch.seq_id[i][0] * n_embd;
|
float * out = output + batch.seq_id[i][0] * n_embd;
|
||||||
normalize(emb, out, n_embd);
|
normalize(embd, out, n_embd);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ async def main():
|
||||||
model_url = "http://127.0.0.1:6900"
|
model_url = "http://127.0.0.1:6900"
|
||||||
responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
|
responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
|
||||||
url= f"{model_url}/embedding",
|
url= f"{model_url}/embedding",
|
||||||
json= {"content": str(0)*32}
|
json= {"content": str(i)*1024}
|
||||||
) for i in range(n)])
|
) for i in range(n)])
|
||||||
|
|
||||||
for response in responses:
|
for response in responses:
|
||||||
|
|
|
@ -1235,12 +1235,22 @@ struct llama_server_context
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const float * data = llama_get_embeddings_ith(ctx, i);
|
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||||
std::vector<float> embedding(data, data + n_embd);
|
if (embd == NULL) {
|
||||||
|
embd = llama_get_embeddings_ith(ctx, i);
|
||||||
|
if (embd == NULL) {
|
||||||
|
LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
|
||||||
|
res.result_json = json
|
||||||
|
{
|
||||||
|
{"embedding", std::vector<float>(n_embd, 0.0f)},
|
||||||
|
};
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
res.result_json = json
|
res.result_json = json
|
||||||
{
|
{
|
||||||
{"embedding", embedding },
|
{"embedding", std::vector<float>(embd, embd + n_embd)},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
141
llama.cpp
141
llama.cpp
|
@ -1983,7 +1983,12 @@ struct llama_context {
|
||||||
bool logits_all = false;
|
bool logits_all = false;
|
||||||
|
|
||||||
// embeddings output (2-dimensional array: [n_tokens][n_embd])
|
// embeddings output (2-dimensional array: [n_tokens][n_embd])
|
||||||
std::vector<float> embeddings;
|
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||||
|
std::vector<float> embd;
|
||||||
|
|
||||||
|
// sequence embeddings output (map of [n_embd] vectors)
|
||||||
|
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
||||||
|
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
||||||
|
|
||||||
// memory buffers used to evaluate the model
|
// memory buffers used to evaluate the model
|
||||||
std::vector<uint8_t> buf_compute_meta;
|
std::vector<uint8_t> buf_compute_meta;
|
||||||
|
@ -6243,12 +6248,23 @@ struct llm_build_context {
|
||||||
cur = inpL;
|
cur = inpL;
|
||||||
|
|
||||||
// pooling layer
|
// pooling layer
|
||||||
if (pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
switch (pooling_type) {
|
||||||
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
|
case LLAMA_POOLING_TYPE_NONE:
|
||||||
} else if (pooling_type == LLAMA_POOLING_TYPE_CLS) {
|
{
|
||||||
cur = ggml_get_rows(ctx0, cur, inp_cls);
|
// nop
|
||||||
} else {
|
} break;
|
||||||
GGML_ASSERT(pooling_type == LLAMA_POOLING_TYPE_NONE && "Invalid pooling type");
|
case LLAMA_POOLING_TYPE_MEAN:
|
||||||
|
{
|
||||||
|
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
|
||||||
|
} break;
|
||||||
|
case LLAMA_POOLING_TYPE_CLS:
|
||||||
|
{
|
||||||
|
cur = ggml_get_rows(ctx0, cur, inp_cls);
|
||||||
|
} break;
|
||||||
|
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(false && "Invalid pooling type");
|
||||||
|
} break;
|
||||||
}
|
}
|
||||||
cb(cur, "result_embd", -1);
|
cb(cur, "result_embd", -1);
|
||||||
|
|
||||||
|
@ -8259,17 +8275,23 @@ static int llama_decode_internal(
|
||||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||||
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
|
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
|
||||||
|
|
||||||
if (strcmp(res->name, "result_output") == 0) {
|
if (!hparams.causal_attn) {
|
||||||
// the embeddings could be the second to last tensor, or the third to last tensor
|
res = nullptr; // do not extract logits for embedding models such as BERT
|
||||||
if (strcmp(embd->name, "result_norm") != 0) {
|
|
||||||
embd = gf->nodes[gf->n_nodes - 3];
|
// token or sequence embeddings
|
||||||
GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
|
embd = gf->nodes[gf->n_nodes - 1];
|
||||||
}
|
|
||||||
} else if (strcmp(res->name, "result_embd") == 0) {
|
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0);
|
||||||
embd = res;
|
|
||||||
res = nullptr;
|
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(false);
|
if (strcmp(res->name, "result_output") == 0) {
|
||||||
|
// the token embeddings could be the second to last tensor, or the third to last tensor
|
||||||
|
if (strcmp(embd->name, "result_norm") != 0) {
|
||||||
|
embd = gf->nodes[gf->n_nodes - 3];
|
||||||
|
GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(false && "missing result_output tensor");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||||
|
@ -8368,30 +8390,46 @@ static int llama_decode_internal(
|
||||||
|
|
||||||
// extract embeddings
|
// extract embeddings
|
||||||
if (cparams.embeddings && embd) {
|
if (cparams.embeddings && embd) {
|
||||||
auto & embeddings_out = lctx.embeddings;
|
|
||||||
|
|
||||||
ggml_backend_t backend_embd = ggml_backend_sched_get_node_backend(lctx.sched, embd);
|
ggml_backend_t backend_embd = ggml_backend_sched_get_node_backend(lctx.sched, embd);
|
||||||
GGML_ASSERT(backend_embd != nullptr);
|
GGML_ASSERT(backend_embd != nullptr);
|
||||||
|
|
||||||
if (batch.logits) {
|
switch (cparams.pooling_type) {
|
||||||
embeddings_out.resize(n_embd * n_tokens);
|
case LLAMA_POOLING_TYPE_NONE:
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
{
|
||||||
if (batch.logits[i] == 0) {
|
// extract token embeddings
|
||||||
continue;
|
auto & embd_out = lctx.embd;
|
||||||
}
|
|
||||||
switch (hparams.pooling_type) {
|
if (batch.logits) {
|
||||||
case LLAMA_POOLING_TYPE_CLS:
|
embd_out.resize(n_embd * n_tokens);
|
||||||
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
break;
|
if (batch.logits[i] == 0) {
|
||||||
case LLAMA_POOLING_TYPE_MEAN:
|
continue;
|
||||||
case LLAMA_POOLING_TYPE_NONE:
|
}
|
||||||
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
|
|
||||||
break;
|
ggml_backend_tensor_get_async(backend_embd, embd, embd_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
|
||||||
default:
|
}
|
||||||
GGML_ASSERT(false && "unknown pooling type");
|
}
|
||||||
break;
|
} break;
|
||||||
}
|
case LLAMA_POOLING_TYPE_CLS:
|
||||||
}
|
case LLAMA_POOLING_TYPE_MEAN:
|
||||||
|
{
|
||||||
|
// extract sequence embeddings
|
||||||
|
auto & embd_seq_out = lctx.embd_seq;
|
||||||
|
embd_seq_out.clear();
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
|
const llama_seq_id seq_id = batch.seq_id[i][0];
|
||||||
|
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
embd_seq_out[seq_id].resize(n_embd);
|
||||||
|
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(false && "unknown pooling type");
|
||||||
|
} break;
|
||||||
}
|
}
|
||||||
ggml_backend_synchronize(backend_embd);
|
ggml_backend_synchronize(backend_embd);
|
||||||
}
|
}
|
||||||
|
@ -12273,7 +12311,7 @@ struct llama_context * llama_new_context_with_model(
|
||||||
ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
|
ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
|
||||||
|
|
||||||
if (params.embeddings) {
|
if (params.embeddings) {
|
||||||
ctx->embeddings.reserve(hparams.n_embd*cparams.n_batch);
|
ctx->embd.reserve(hparams.n_embd*cparams.n_batch);
|
||||||
}
|
}
|
||||||
|
|
||||||
// graph inputs
|
// graph inputs
|
||||||
|
@ -12708,7 +12746,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
|
||||||
// assume worst case for logits although only currently set ones are serialized
|
// assume worst case for logits although only currently set ones are serialized
|
||||||
const size_t s_logits = ctx->logits.capacity() * sizeof(float);
|
const size_t s_logits = ctx->logits.capacity() * sizeof(float);
|
||||||
const size_t s_embedding_size = sizeof(size_t);
|
const size_t s_embedding_size = sizeof(size_t);
|
||||||
const size_t s_embedding = ctx->embeddings.capacity() * sizeof(float);
|
const size_t s_embedding = ctx->embd.capacity() * sizeof(float);
|
||||||
const size_t s_kv_buf_size = sizeof(size_t);
|
const size_t s_kv_buf_size = sizeof(size_t);
|
||||||
const size_t s_kv_head = sizeof(uint32_t);
|
const size_t s_kv_head = sizeof(uint32_t);
|
||||||
const size_t s_kv_size = sizeof(uint32_t);
|
const size_t s_kv_size = sizeof(uint32_t);
|
||||||
|
@ -12817,12 +12855,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
|
||||||
|
|
||||||
// copy embeddings
|
// copy embeddings
|
||||||
{
|
{
|
||||||
const size_t embeddings_size = ctx->embeddings.size();
|
const size_t embeddings_size = ctx->embd.size();
|
||||||
|
|
||||||
data_ctx->write(&embeddings_size, sizeof(embeddings_size));
|
data_ctx->write(&embeddings_size, sizeof(embeddings_size));
|
||||||
|
|
||||||
if (embeddings_size) {
|
if (embeddings_size) {
|
||||||
data_ctx->write(ctx->embeddings.data(), embeddings_size * sizeof(float));
|
data_ctx->write(ctx->embd.data(), embeddings_size * sizeof(float));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12930,12 +12968,12 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
||||||
|
|
||||||
memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
|
memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
|
||||||
|
|
||||||
GGML_ASSERT(ctx->embeddings.capacity() == embeddings_size);
|
GGML_ASSERT(ctx->embd.capacity() == embeddings_size);
|
||||||
|
|
||||||
if (embeddings_size) {
|
if (embeddings_size) {
|
||||||
ctx->embeddings.resize(embeddings_size);
|
ctx->embd.resize(embeddings_size);
|
||||||
|
|
||||||
memcpy(ctx->embeddings.data(), inp, embeddings_size * sizeof(float));
|
memcpy(ctx->embd.data(), inp, embeddings_size * sizeof(float));
|
||||||
inp += embeddings_size * sizeof(float);
|
inp += embeddings_size * sizeof(float);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -13186,11 +13224,20 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
|
||||||
}
|
}
|
||||||
|
|
||||||
float * llama_get_embeddings(struct llama_context * ctx) {
|
float * llama_get_embeddings(struct llama_context * ctx) {
|
||||||
return ctx->embeddings.data();
|
return ctx->embd.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
|
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
|
||||||
return ctx->embeddings.data() + i*ctx->model.hparams.n_embd;
|
return ctx->embd.data() + i*ctx->model.hparams.n_embd;
|
||||||
|
}
|
||||||
|
|
||||||
|
float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||||
|
auto it = ctx->embd_seq.find(seq_id);
|
||||||
|
if (it == ctx->embd_seq.end()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
return it->second.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
|
const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
|
||||||
|
|
10
llama.h
10
llama.h
|
@ -655,14 +655,20 @@ extern "C" {
|
||||||
// llama_get_logits(ctx) + i*n_vocab
|
// llama_get_logits(ctx) + i*n_vocab
|
||||||
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
|
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
|
||||||
|
|
||||||
// Get the embeddings for the input
|
// Get all output token embeddings
|
||||||
// shape: [n_embd] (1-dimensional)
|
// shape: [n_tokens*n_embd] (1-dimensional)
|
||||||
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||||
|
|
||||||
// Get the embeddings for the ith token
|
// Get the embeddings for the ith token
|
||||||
// llama_get_embeddings(ctx) + i*n_embd
|
// llama_get_embeddings(ctx) + i*n_embd
|
||||||
|
// shape: [n_embd] (1-dimensional)
|
||||||
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
||||||
|
|
||||||
|
// Get the embeddings for a sequence id
|
||||||
|
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
|
||||||
|
// shape: [n_embd] (1-dimensional)
|
||||||
|
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
|
||||||
|
|
||||||
//
|
//
|
||||||
// Vocab
|
// Vocab
|
||||||
//
|
//
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue