llama : distinguish token vs sequence embeddings
ggml-ci
This commit is contained in:
parent
e66da356a4
commit
79e4eede23
5 changed files with 128 additions and 56 deletions
|
@ -23,7 +23,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
|
|||
}
|
||||
}
|
||||
|
||||
static void normalize(float * vec, float * out, int n) {
|
||||
static void normalize(const float * vec, float * out, int n) {
|
||||
float norm = 0;
|
||||
for (int i = 0; i < n; i++) {
|
||||
norm += vec[i] * vec[i];
|
||||
|
@ -50,9 +50,18 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
|||
continue;
|
||||
}
|
||||
|
||||
float * emb = llama_get_embeddings_ith(ctx, i);
|
||||
// try to get sequence embeddings - supported only when pooling_type is not NONE
|
||||
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||
if (embd == NULL) {
|
||||
embd = llama_get_embeddings_ith(ctx, i);
|
||||
if (embd == NULL) {
|
||||
fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
float * out = output + batch.seq_id[i][0] * n_embd;
|
||||
normalize(emb, out, n_embd);
|
||||
normalize(embd, out, n_embd);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ async def main():
|
|||
model_url = "http://127.0.0.1:6900"
|
||||
responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
|
||||
url= f"{model_url}/embedding",
|
||||
json= {"content": str(0)*32}
|
||||
json= {"content": str(i)*1024}
|
||||
) for i in range(n)])
|
||||
|
||||
for response in responses:
|
||||
|
|
|
@ -1235,12 +1235,22 @@ struct llama_server_context
|
|||
continue;
|
||||
}
|
||||
|
||||
const float * data = llama_get_embeddings_ith(ctx, i);
|
||||
std::vector<float> embedding(data, data + n_embd);
|
||||
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||
if (embd == NULL) {
|
||||
embd = llama_get_embeddings_ith(ctx, i);
|
||||
if (embd == NULL) {
|
||||
LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
|
||||
res.result_json = json
|
||||
{
|
||||
{"embedding", std::vector<float>(n_embd, 0.0f)},
|
||||
};
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
res.result_json = json
|
||||
{
|
||||
{"embedding", embedding },
|
||||
{"embedding", std::vector<float>(embd, embd + n_embd)},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
111
llama.cpp
111
llama.cpp
|
@ -1983,7 +1983,12 @@ struct llama_context {
|
|||
bool logits_all = false;
|
||||
|
||||
// embeddings output (2-dimensional array: [n_tokens][n_embd])
|
||||
std::vector<float> embeddings;
|
||||
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||
std::vector<float> embd;
|
||||
|
||||
// sequence embeddings output (map of [n_embd] vectors)
|
||||
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
||||
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
||||
|
||||
// memory buffers used to evaluate the model
|
||||
std::vector<uint8_t> buf_compute_meta;
|
||||
|
@ -6243,12 +6248,23 @@ struct llm_build_context {
|
|||
cur = inpL;
|
||||
|
||||
// pooling layer
|
||||
if (pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
||||
switch (pooling_type) {
|
||||
case LLAMA_POOLING_TYPE_NONE:
|
||||
{
|
||||
// nop
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
|
||||
} else if (pooling_type == LLAMA_POOLING_TYPE_CLS) {
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_CLS:
|
||||
{
|
||||
cur = ggml_get_rows(ctx0, cur, inp_cls);
|
||||
} else {
|
||||
GGML_ASSERT(pooling_type == LLAMA_POOLING_TYPE_NONE && "Invalid pooling type");
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||
{
|
||||
GGML_ASSERT(false && "Invalid pooling type");
|
||||
} break;
|
||||
}
|
||||
cb(cur, "result_embd", -1);
|
||||
|
||||
|
@ -8259,17 +8275,23 @@ static int llama_decode_internal(
|
|||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
|
||||
|
||||
if (!hparams.causal_attn) {
|
||||
res = nullptr; // do not extract logits for embedding models such as BERT
|
||||
|
||||
// token or sequence embeddings
|
||||
embd = gf->nodes[gf->n_nodes - 1];
|
||||
|
||||
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0);
|
||||
} else {
|
||||
if (strcmp(res->name, "result_output") == 0) {
|
||||
// the embeddings could be the second to last tensor, or the third to last tensor
|
||||
// the token embeddings could be the second to last tensor, or the third to last tensor
|
||||
if (strcmp(embd->name, "result_norm") != 0) {
|
||||
embd = gf->nodes[gf->n_nodes - 3];
|
||||
GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
|
||||
}
|
||||
} else if (strcmp(res->name, "result_embd") == 0) {
|
||||
embd = res;
|
||||
res = nullptr;
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
GGML_ASSERT(false && "missing result_output tensor");
|
||||
}
|
||||
}
|
||||
|
||||
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||
|
@ -8368,30 +8390,46 @@ static int llama_decode_internal(
|
|||
|
||||
// extract embeddings
|
||||
if (cparams.embeddings && embd) {
|
||||
auto & embeddings_out = lctx.embeddings;
|
||||
|
||||
ggml_backend_t backend_embd = ggml_backend_sched_get_node_backend(lctx.sched, embd);
|
||||
GGML_ASSERT(backend_embd != nullptr);
|
||||
|
||||
switch (cparams.pooling_type) {
|
||||
case LLAMA_POOLING_TYPE_NONE:
|
||||
{
|
||||
// extract token embeddings
|
||||
auto & embd_out = lctx.embd;
|
||||
|
||||
if (batch.logits) {
|
||||
embeddings_out.resize(n_embd * n_tokens);
|
||||
embd_out.resize(n_embd * n_tokens);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
if (batch.logits[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
switch (hparams.pooling_type) {
|
||||
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_CLS:
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
|
||||
break;
|
||||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
case LLAMA_POOLING_TYPE_NONE:
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
|
||||
break;
|
||||
default:
|
||||
{
|
||||
// extract sequence embeddings
|
||||
auto & embd_seq_out = lctx.embd_seq;
|
||||
embd_seq_out.clear();
|
||||
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
const llama_seq_id seq_id = batch.seq_id[i][0];
|
||||
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
||||
continue;
|
||||
}
|
||||
embd_seq_out[seq_id].resize(n_embd);
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||
{
|
||||
GGML_ASSERT(false && "unknown pooling type");
|
||||
break;
|
||||
}
|
||||
}
|
||||
} break;
|
||||
}
|
||||
ggml_backend_synchronize(backend_embd);
|
||||
}
|
||||
|
@ -12273,7 +12311,7 @@ struct llama_context * llama_new_context_with_model(
|
|||
ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
|
||||
|
||||
if (params.embeddings) {
|
||||
ctx->embeddings.reserve(hparams.n_embd*cparams.n_batch);
|
||||
ctx->embd.reserve(hparams.n_embd*cparams.n_batch);
|
||||
}
|
||||
|
||||
// graph inputs
|
||||
|
@ -12708,7 +12746,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
|
|||
// assume worst case for logits although only currently set ones are serialized
|
||||
const size_t s_logits = ctx->logits.capacity() * sizeof(float);
|
||||
const size_t s_embedding_size = sizeof(size_t);
|
||||
const size_t s_embedding = ctx->embeddings.capacity() * sizeof(float);
|
||||
const size_t s_embedding = ctx->embd.capacity() * sizeof(float);
|
||||
const size_t s_kv_buf_size = sizeof(size_t);
|
||||
const size_t s_kv_head = sizeof(uint32_t);
|
||||
const size_t s_kv_size = sizeof(uint32_t);
|
||||
|
@ -12817,12 +12855,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
|
|||
|
||||
// copy embeddings
|
||||
{
|
||||
const size_t embeddings_size = ctx->embeddings.size();
|
||||
const size_t embeddings_size = ctx->embd.size();
|
||||
|
||||
data_ctx->write(&embeddings_size, sizeof(embeddings_size));
|
||||
|
||||
if (embeddings_size) {
|
||||
data_ctx->write(ctx->embeddings.data(), embeddings_size * sizeof(float));
|
||||
data_ctx->write(ctx->embd.data(), embeddings_size * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -12930,12 +12968,12 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
|||
|
||||
memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
|
||||
|
||||
GGML_ASSERT(ctx->embeddings.capacity() == embeddings_size);
|
||||
GGML_ASSERT(ctx->embd.capacity() == embeddings_size);
|
||||
|
||||
if (embeddings_size) {
|
||||
ctx->embeddings.resize(embeddings_size);
|
||||
ctx->embd.resize(embeddings_size);
|
||||
|
||||
memcpy(ctx->embeddings.data(), inp, embeddings_size * sizeof(float));
|
||||
memcpy(ctx->embd.data(), inp, embeddings_size * sizeof(float));
|
||||
inp += embeddings_size * sizeof(float);
|
||||
}
|
||||
}
|
||||
|
@ -13186,11 +13224,20 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
|
|||
}
|
||||
|
||||
float * llama_get_embeddings(struct llama_context * ctx) {
|
||||
return ctx->embeddings.data();
|
||||
return ctx->embd.data();
|
||||
}
|
||||
|
||||
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
|
||||
return ctx->embeddings.data() + i*ctx->model.hparams.n_embd;
|
||||
return ctx->embd.data() + i*ctx->model.hparams.n_embd;
|
||||
}
|
||||
|
||||
float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||
auto it = ctx->embd_seq.find(seq_id);
|
||||
if (it == ctx->embd_seq.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return it->second.data();
|
||||
}
|
||||
|
||||
const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
|
||||
|
|
10
llama.h
10
llama.h
|
@ -655,14 +655,20 @@ extern "C" {
|
|||
// llama_get_logits(ctx) + i*n_vocab
|
||||
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
// Get the embeddings for the input
|
||||
// shape: [n_embd] (1-dimensional)
|
||||
// Get all output token embeddings
|
||||
// shape: [n_tokens*n_embd] (1-dimensional)
|
||||
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
|
||||
// Get the embeddings for the ith token
|
||||
// llama_get_embeddings(ctx) + i*n_embd
|
||||
// shape: [n_embd] (1-dimensional)
|
||||
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
// Get the embeddings for a sequence id
|
||||
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
|
||||
// shape: [n_embd] (1-dimensional)
|
||||
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
|
||||
|
||||
//
|
||||
// Vocab
|
||||
//
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue