server : output embeddings for all tokens when pooling = none
ggml-ci
This commit is contained in:
parent
44eeb6a88e
commit
07946a3a30
3 changed files with 32 additions and 5 deletions
|
@ -727,7 +727,7 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||
|
||||
struct server_task_result_embd : server_task_result {
|
||||
int index = 0;
|
||||
std::vector<float> embedding;
|
||||
std::vector<std::vector<float>> embedding;
|
||||
|
||||
int32_t n_tokens;
|
||||
|
||||
|
@ -736,6 +736,14 @@ struct server_task_result_embd : server_task_result {
|
|||
}
|
||||
|
||||
virtual json to_json() override {
|
||||
if (embedding.size() == 1){
|
||||
// to be OAI compatible
|
||||
return json {
|
||||
{"index", index},
|
||||
{"embedding", embedding[0]},
|
||||
};
|
||||
}
|
||||
|
||||
return json {
|
||||
{"index", index},
|
||||
{"embedding", embedding},
|
||||
|
@ -2040,12 +2048,12 @@ struct server_context {
|
|||
if (embd == NULL) {
|
||||
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
|
||||
|
||||
res->embedding = std::vector<float>(n_embd, 0.0f);
|
||||
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
|
||||
continue;
|
||||
}
|
||||
|
||||
common_embd_normalize(embd, embd_res.data(), n_embd);
|
||||
res->embedding = embd_res;
|
||||
res->embedding.push_back(embd_res);
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "%s", "sending embeddings\n");
|
||||
|
@ -2659,7 +2667,10 @@ struct server_context {
|
|||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
|
||||
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
|
||||
// without pooling, we want to output the embeddings for all the tokens in the batch
|
||||
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
|
||||
|
|
|
@ -74,6 +74,18 @@ def test_embedding_mixed_input(content, is_multi_prompt: bool):
|
|||
assert len(res.body['embedding']) > 1
|
||||
|
||||
|
||||
def test_embedding_pooling_none():
|
||||
server = ServerPreset.bert_bge_small(pooling = 'none')
|
||||
server.start()
|
||||
res = server.make_request("POST", "/embeddings", data={
|
||||
"input": "hello hello hello",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body['data']) == 1
|
||||
assert 'embedding' in res.body['data'][0]
|
||||
assert len(res.body['data'][0]['embedding']) == 3
|
||||
|
||||
|
||||
def test_embedding_openai_library_single():
|
||||
global server
|
||||
server.start()
|
||||
|
|
|
@ -65,6 +65,7 @@ class ServerProcess:
|
|||
server_reranking: bool | None = False
|
||||
server_metrics: bool | None = False
|
||||
server_slots: bool | None = False
|
||||
pooling: str | None = None
|
||||
draft: int | None = None
|
||||
api_key: str | None = None
|
||||
response_format: str | None = None
|
||||
|
@ -132,6 +133,8 @@ class ServerProcess:
|
|||
server_args.append("--metrics")
|
||||
if self.server_slots:
|
||||
server_args.append("--slots")
|
||||
if self.pooling:
|
||||
server_args.extend(["--pooling", self.pooling])
|
||||
if self.model_alias:
|
||||
server_args.extend(["--alias", self.model_alias])
|
||||
if self.n_ctx:
|
||||
|
@ -272,7 +275,7 @@ class ServerPreset:
|
|||
return server
|
||||
|
||||
@staticmethod
|
||||
def bert_bge_small() -> ServerProcess:
|
||||
def bert_bge_small(pooling = 'last') -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
|
||||
|
@ -283,6 +286,7 @@ class ServerPreset:
|
|||
server.n_slots = 2
|
||||
server.seed = 42
|
||||
server.server_embeddings = True
|
||||
server.pooling = pooling
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue