server : do not normalize embeddings when there is no pooling

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-12-17 13:36:32 +02:00
parent abf33e2017
commit 7e693f92d7
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
6 changed files with 20 additions and 6 deletions

View file

@@ -1780,7 +1780,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm)
             break;
         case 0: // max absolute
            for (int i = 0; i < n; i++) {
-                if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
+                if (sum < std::abs(inp[i])) {
+                    sum = std::abs(inp[i]);
+                }
            }
            sum /= 32760.0; // make an int16 range
            break;

View file

@@ -596,7 +596,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si
 //
 // Embedding utils
 //
-void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
+// TODO: repace embd_norm with an enum
+void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
 
 float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);

View file

@@ -75,7 +75,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
        }
 
        std::vector<float> emb_norm(emb_unorm.size());
-        common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
+        common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2);
        result.push_back(emb_norm);
 
 #ifdef GRIT_DEBUG

View file

@@ -107,7 +107,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
        }
 
        float * out = output + batch.seq_id[i][0] * n_embd;
-        common_embd_normalize(embd, out, n_embd);
+        common_embd_normalize(embd, out, n_embd, 2);
    }
 }

View file

@@ -2059,8 +2059,14 @@ struct server_context {
                    continue;
                }
 
-                common_embd_normalize(embd, embd_res.data(), n_embd);
-                res->embedding.push_back(embd_res);
+                // normalize only when there is pooling
+                // TODO: configurable
+                if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
+                    common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+                    res->embedding.push_back(embd_res);
+                } else {
+                    res->embedding.push_back({ embd, embd + n_embd });
+                }
            }
 
        SLT_DBG(slot, "%s", "sending embeddings\n");

View file

@@ -87,6 +87,10 @@ def test_embedding_pooling_none():
     assert 'embedding' in res.body[0]
     assert len(res.body[0]['embedding']) == 3
+
+    # make sure embedding vector is not normalized
+    for x in res.body[0]['embedding']:
+        assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON
 
 def test_embedding_pooling_none_oai():
     global server
@@ -95,6 +99,7 @@ def test_embedding_pooling_none_oai():
     res = server.make_request("POST", "/v1/embeddings", data={
         "input": "hello hello hello",
     })
+
     # /v1/embeddings does not support pooling type 'none'
     assert res.status_code == 400