llama : fix embeddings

ggml-ci
Georgi Gerganov 2024-02-29 15:39:10 +02:00
parent a0fc62661f
commit d0347840c1
6 changed files with 127 additions and 62 deletions


@@ -1210,7 +1210,7 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    void send_embedding(server_slot &slot)
+    void send_embedding(server_slot & slot, const llama_batch & batch)
     {
         task_result res;
         res.id = slot.task_id;
@@ -1219,6 +1219,7 @@ struct llama_server_context
         res.stop = true;
 
         const int n_embd = llama_n_embd(model);
+
         if (!params.embedding)
         {
             LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}});
@@ -1229,12 +1230,19 @@ struct llama_server_context
         }
         else
         {
-            const float *data = llama_get_embeddings(ctx);
-            std::vector<float> embedding(data, data + n_embd);
-            res.result_json = json
-            {
-                {"embedding", embedding},
-            };
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * data = llama_get_embeddings_ith(ctx, i);
+                std::vector<float> embedding(data, data + n_embd);
+
+                res.result_json = json
+                {
+                    {"embedding", embedding },
+                };
+            }
         }
         queue_results.send(res);
     }
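
Note: the per-token retrieval above relies on llama_get_embeddings_ith(), which
returns the embedding row for the i-th token of the most recent llama_decode()
call. A minimal standalone sketch of the same selection logic, assuming the
llama.cpp C API of this period (the helper name and variables are illustrative,
not part of this commit):

    #include "llama.h"
    #include <vector>

    // Return the embedding of the output token that belongs to seq_id,
    // assuming `batch` was just decoded with llama_decode(ctx, batch).
    static std::vector<float> embedding_for_seq(
            llama_context * ctx, const llama_model * model,
            const llama_batch & batch, llama_seq_id seq_id) {
        const int n_embd = llama_n_embd(model);

        for (int i = 0; i < batch.n_tokens; ++i) {
            // skip tokens with no output row, or belonging to another sequence
            if (!batch.logits[i] || batch.seq_id[i][0] != seq_id) {
                continue;
            }

            const float * data = llama_get_embeddings_ith(ctx, i);
            return std::vector<float>(data, data + n_embd);
        }

        return {}; // no output token found for this sequence
    }
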
@@ -1845,7 +1853,7 @@ struct llama_server_context
                                 ga_i += ga_w/ga_n;
                             }
                         }
 
-                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
+                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
                         slot_npast++;
                     }
@@ -1881,7 +1889,7 @@ struct llama_server_context
 
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
         {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
             for (auto & slot : slots)
             {
@@ -1954,7 +1962,7 @@ struct llama_server_context
                 // prompt evaluated for embedding
                 if (slot.embedding)
                 {
-                    send_embedding(slot);
+                    send_embedding(slot, batch_view);
                     slot.release();
                     slot.i_batch = -1;
                     continue;
@@ -2330,7 +2338,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 break;
             }
             params.n_batch = std::stoi(argv[i]);
-            params.n_batch = std::min(512, params.n_batch);
         }
         else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
         {
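
Note: with the std::min(512, params.n_batch) clamp removed, the server passes
the --batch-size/-b value through unmodified, so an invocation such as
./server -m model.gguf -b 2048 (hypothetical example) now decodes in chunks of
2048 tokens instead of being silently capped at 512.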