server : output embeddings for all tokens when pooling = none
ggml-ci
parent 44eeb6a88e
commit 07946a3a30

3 changed files with 32 additions and 5 deletions
examples/server/server.cpp

@@ -727,7 +727,7 @@ struct server_task_result_cmpl_partial : server_task_result {
 struct server_task_result_embd : server_task_result {
     int index = 0;
-    std::vector<float> embedding;
+    std::vector<std::vector<float>> embedding;

     int32_t n_tokens;

@@ -736,6 +736,14 @@ struct server_task_result_embd : server_task_result {
     }

     virtual json to_json() override {
+        if (embedding.size() == 1){
+            // to be OAI compatible
+            return json {
+                {"index",     index},
+                {"embedding", embedding[0]},
+            };
+        }
+
         return json {
             {"index",     index},
             {"embedding", embedding},
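Note on the response shape: with this change, to_json keeps the OpenAI-compatible flat array when there is exactly one (pooled) vector, and returns an array of per-token vectors otherwise. A minimal client-side sketch (not part of this commit) that accepts both shapes:

    def as_token_vectors(embedding):
        # pooled result (OAI-compatible): a flat list of numbers
        if embedding and not isinstance(embedding[0], list):
            return [embedding]
        # pooling = none: already a list of per-token vectors
        return embedding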
@@ -2040,12 +2048,12 @@ struct server_context {
             if (embd == NULL) {
                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);

-                res->embedding = std::vector<float>(n_embd, 0.0f);
+                res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                 continue;
             }

             common_embd_normalize(embd, embd_res.data(), n_embd);
-            res->embedding = embd_res;
+            res->embedding.push_back(embd_res);
         }

         SLT_DBG(slot, "%s", "sending embeddings\n");
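For reference, each per-token row is passed through common_embd_normalize before being pushed, and a row whose backend output is missing is replaced by an n_embd-sized zero vector. Assuming the default Euclidean (L2) normalization applies here, the equivalent per-row step in Python is roughly this (a sketch, not the server code):

    import math

    def l2_normalize(row: list[float]) -> list[float]:
        # divide each component by the Euclidean norm; leave an
        # all-zero row (the failure placeholder above) unchanged
        norm = math.sqrt(sum(x * x for x in row))
        return [x / norm for x in row] if norm > 0.0 else row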
@@ -2659,7 +2667,10 @@ struct server_context {

                     // add prompt tokens for processing in the current batch
                     while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
+                        // without pooling, we want to output the embeddings for all the tokens in the batch
+                        const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+
+                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);

                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
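The last argument of common_batch_add is the per-token output flag; it used to be hard-coded to false for prompt tokens because only the pooled embedding was read back. A Python restatement of the new loop decision (illustrative only; the names mirror the C++ above):

    def add_prompt_tokens(batch, prompt_tokens, task_type, pooling_type):
        # without pooling, every prompt token must produce an output row,
        # so the output flag is set for all of them rather than none
        need_embd = task_type == "embedding" and pooling_type == "none"
        for pos, token in enumerate(prompt_tokens):
            batch.append((token, pos, need_embd))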
examples/server/tests/unit/test_embedding.py

@@ -74,6 +74,18 @@ def test_embedding_mixed_input(content, is_multi_prompt: bool):
         assert len(res.body['embedding']) > 1


+def test_embedding_pooling_none():
+    server = ServerPreset.bert_bge_small(pooling = 'none')
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": "hello hello hello",
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 1
+    assert 'embedding' in res.body['data'][0]
+    assert len(res.body['data'][0]['embedding']) == 3
+
+
 def test_embedding_openai_library_single():
     global server
     server.start()
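The new test can also be reproduced against a manually started server; a sketch assuming llama-server is running on localhost:8080 with --embeddings --pooling none:

    import requests

    res = requests.post("http://localhost:8080/embeddings",
                        json={"input": "hello hello hello"})
    res.raise_for_status()
    data = res.json()["data"]
    # one result per input; with pooling = none, "embedding" holds one
    # vector per token instead of a single pooled vector
    print(len(data), len(data[0]["embedding"]))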
examples/server/tests/utils.py

@@ -65,6 +65,7 @@ class ServerProcess:
     server_reranking: bool | None = False
     server_metrics: bool | None = False
     server_slots: bool | None = False
+    pooling: str | None = None
     draft: int | None = None
     api_key: str | None = None
     response_format: str | None = None

@@ -132,6 +133,8 @@ class ServerProcess:
             server_args.append("--metrics")
         if self.server_slots:
             server_args.append("--slots")
+        if self.pooling:
+            server_args.extend(["--pooling", self.pooling])
         if self.model_alias:
             server_args.extend(["--alias", self.model_alias])
         if self.n_ctx:
@@ -272,7 +275,7 @@ class ServerPreset:
         return server

     @staticmethod
-    def bert_bge_small() -> ServerProcess:
+    def bert_bge_small(pooling = 'last') -> ServerProcess:
         server = ServerProcess()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"

@@ -283,6 +286,7 @@ class ServerPreset:
         server.n_slots = 2
         server.seed = 42
         server.server_embeddings = True
+        server.pooling = pooling
         return server

     @staticmethod
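Usage within the test suite, following the plumbing above: the preset argument sets ServerProcess.pooling, which start() turns into a --pooling flag on the llama-server command line.

    # inside a test module that already does `from utils import *`
    server = ServerPreset.bert_bge_small(pooling='none')
    server.start()   # launches llama-server with ... --pooling none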