diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 5b6d660b8..eac9ada5f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -727,7 +727,7 @@ struct server_task_result_cmpl_partial : server_task_result {
 
 struct server_task_result_embd : server_task_result {
     int index = 0;
-    std::vector<float> embedding;
+    std::vector<std::vector<float>> embedding;
 
     int32_t n_tokens;
 
@@ -736,6 +736,14 @@
     }
 
     virtual json to_json() override {
+        if (embedding.size() == 1){
+            // to be OAI compatible
+            return json {
+                {"index", index},
+                {"embedding", embedding[0]},
+            };
+        }
+
         return json {
             {"index", index},
             {"embedding", embedding},
@@ -2040,12 +2048,12 @@ struct server_context {
             if (embd == NULL) {
                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
 
-                res->embedding = std::vector<float>(n_embd, 0.0f);
+                res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                 continue;
             }
 
             common_embd_normalize(embd, embd_res.data(), n_embd);
-            res->embedding = embd_res;
+            res->embedding.push_back(embd_res);
         }
 
         SLT_DBG(slot, "%s", "sending embeddings\n");
@@ -2659,7 +2667,10 @@ struct server_context {
 
                     // add prompt tokens for processing in the current batch
                     while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
+                        // without pooling, we want to output the embeddings for all the tokens in the batch
+                        const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+
+                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
 
                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py
index 4f4e9dcf0..d6a3b6125 100644
--- a/examples/server/tests/unit/test_embedding.py
+++ b/examples/server/tests/unit/test_embedding.py
@@ -74,6 +74,18 @@ def test_embedding_mixed_input(content, is_multi_prompt: bool):
         assert len(res.body['embedding']) > 1
 
 
+def test_embedding_pooling_none():
+    server = ServerPreset.bert_bge_small(pooling = 'none')
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": "hello hello hello",
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 1
+    assert 'embedding' in res.body['data'][0]
+    assert len(res.body['data'][0]['embedding']) == 3
+
+
 def test_embedding_openai_library_single():
     global server
     server.start()
diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py
index d988ccf5e..da95c830b 100644
--- a/examples/server/tests/utils.py
+++ b/examples/server/tests/utils.py
@@ -65,6 +65,7 @@ class ServerProcess:
     server_reranking: bool | None = False
     server_metrics: bool | None = False
     server_slots: bool | None = False
+    pooling: str | None = None
    draft: int | None = None
    api_key: str | None = None
    response_format: str | None = None
@@ -132,6 +133,8 @@ class ServerProcess:
            server_args.append("--metrics")
        if self.server_slots:
            server_args.append("--slots")
+        if self.pooling:
+            server_args.extend(["--pooling", self.pooling])
        if self.model_alias:
            server_args.extend(["--alias", self.model_alias])
        if self.n_ctx:
@@ -272,7 +275,7 @@ class ServerPreset:
        return server

    @staticmethod
-    def bert_bge_small() -> ServerProcess:
+    def bert_bge_small(pooling = 'last') -> ServerProcess:
        server = ServerProcess()
        server.model_hf_repo = "ggml-org/models"
        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
@@ -283,6 +286,7 @@ class ServerPreset:
        server.n_slots = 2
        server.seed = 42
        server.server_embeddings = True
+        server.pooling = pooling
        return server

    @staticmethod
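Usage sketch (not part of the patch): a minimal example of the new behaviour, assuming a local llama-server started with --embeddings --pooling none on the default port 8080. It mirrors the new test above, where "hello hello hello" yields 3 per-token vectors for bert-bge-small; the exact count depends on the model's tokenizer.

    import requests

    # With --pooling none, /embeddings returns one vector per input token
    # instead of a single pooled vector. The single-vector case keeps the
    # OAI-compatible flat shape via the to_json() branch added above.
    res = requests.post("http://localhost:8080/embeddings", json={
        "input": "hello hello hello",
    })
    res.raise_for_status()

    entry = res.json()["data"][0]
    token_vectors = entry["embedding"]  # list of per-token embedding vectors
    print(len(token_vectors), "token embeddings of dim", len(token_vectors[0]))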