server : output embeddings for all tokens when pooling = none

ggml-ci
Author: Georgi Gerganov
Date:   2024-12-17 10:56:20 +02:00
commit 07946a3a30
parent 44eeb6a88e
3 changed files with 32 additions and 5 deletions

examples/server/server.cpp

@@ -727,7 +727,7 @@ struct server_task_result_cmpl_partial : server_task_result {

 struct server_task_result_embd : server_task_result {
     int index = 0;
-    std::vector<float> embedding;
+    std::vector<std::vector<float>> embedding;

     int32_t n_tokens;

@@ -736,6 +736,14 @@ struct server_task_result_embd : server_task_result {
     }

     virtual json to_json() override {
+        if (embedding.size() == 1){
+            // to be OAI compatible
+            return json {
+                {"index",     index},
+                {"embedding", embedding[0]},
+            };
+        }
+
         return json {
             {"index",     index},
             {"embedding", embedding},
@@ -2040,12 +2048,12 @@ struct server_context {
             if (embd == NULL) {
                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);

-                res->embedding = std::vector<float>(n_embd, 0.0f);
+                res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                 continue;
             }

             common_embd_normalize(embd, embd_res.data(), n_embd);
-            res->embedding = embd_res;
+            res->embedding.push_back(embd_res);
         }

         SLT_DBG(slot, "%s", "sending embeddings\n");
@@ -2659,7 +2667,10 @@ struct server_context {
                 // add prompt tokens for processing in the current batch
                 while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-                    common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
+                    // without pooling, we want to output the embeddings for all the tokens in the batch
+                    const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+
+                    common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);

                     if (slot.params.cache_prompt) {
                         slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
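
The behavioral switch is the last argument of common_batch_add: previously every prompt token was added with output disabled (false); now embedding tasks running on a context created with LLAMA_POOLING_TYPE_NONE request output for each token, which is what lets the send loop above collect one embedding per token. A standalone Python restatement of the predicate, mirroring the C++ (hypothetical helper, for illustration only):

    def need_embd(task_type: str, pooling_type: str) -> bool:
        # per-token output is requested only for embedding tasks on a
        # context built with --pooling none; all other slots are unchanged
        return task_type == "embedding" and pooling_type == "none"

    assert need_embd("embedding", "none")
    assert not need_embd("embedding", "mean")
    assert not need_embd("completion", "none")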

examples/server/tests/unit/test_embedding.py

@@ -74,6 +74,18 @@ def test_embedding_mixed_input(content, is_multi_prompt: bool):
         assert len(res.body['embedding']) > 1


+def test_embedding_pooling_none():
+    server = ServerPreset.bert_bge_small(pooling = 'none')
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": "hello hello hello",
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 1
+    assert 'embedding' in res.body['data'][0]
+    assert len(res.body['data'][0]['embedding']) == 3
+
+
 def test_embedding_openai_library_single():
     global server
     server.start()

examples/server/tests/utils.py

@@ -65,6 +65,7 @@ class ServerProcess:
     server_reranking: bool | None = False
     server_metrics: bool | None = False
     server_slots: bool | None = False
+    pooling: str | None = None
     draft: int | None = None
     api_key: str | None = None
     response_format: str | None = None
@@ -132,6 +133,8 @@ class ServerProcess:
             server_args.append("--metrics")
         if self.server_slots:
             server_args.append("--slots")
+        if self.pooling:
+            server_args.extend(["--pooling", self.pooling])
         if self.model_alias:
             server_args.extend(["--alias", self.model_alias])
         if self.n_ctx:
@@ -272,7 +275,7 @@ class ServerPreset:
         return server

     @staticmethod
-    def bert_bge_small() -> ServerProcess:
+    def bert_bge_small(pooling = 'last') -> ServerProcess:
         server = ServerProcess()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
@@ -283,6 +286,7 @@ class ServerPreset:
         server.n_slots = 2
         server.seed = 42
         server.server_embeddings = True
+        server.pooling = pooling
         return server

     @staticmethod
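
The harness changes are backward compatible: callers that do not pass the new argument run with --pooling last, while a test can opt into per-token output explicitly. A usage sketch in the harness's own style (assuming the classes above):

    # pooled embeddings (default pooling = 'last')
    server = ServerPreset.bert_bge_small()

    # per-token embeddings: start() launches llama-server with --pooling none
    server = ServerPreset.bert_bge_small(pooling = 'none')
    server.start()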