server : do not normalize embeddings when there is no pooling
ggml-ci
parent abf33e2017
commit 7e693f92d7
6 changed files with 20 additions and 6 deletions
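For context: "normalize" throughout this commit means Euclidean (L2) normalization, the embd_norm == 2 case that every call site in this diff now selects explicitly. A minimal standalone sketch of the operation (an illustration written for this note, not the library implementation; the function name is invented):

#include <cmath>

// Sketch of L2 normalization: scale the vector to unit Euclidean length.
// This mirrors what embd_norm == 2 selects in common_embd_normalize.
static void l2_normalize_sketch(const float * inp, float * out, int n) {
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        sum += (double) inp[i] * inp[i]; // accumulate squared components
    }
    const double norm = std::sqrt(sum);
    const float scale = norm > 0.0 ? (float) (1.0 / norm) : 0.0f;
    for (int i = 0; i < n; i++) {
        out[i] = inp[i] * scale; // unit length when the input is non-zero
    }
}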
|
@ -1780,7 +1780,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm)
|
|||
break;
|
||||
case 0: // max absolute
|
||||
for (int i = 0; i < n; i++) {
|
||||
if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
|
||||
if (sum < std::abs(inp[i])) {
|
||||
sum = std::abs(inp[i]);
|
||||
}
|
||||
}
|
||||
sum /= 32760.0; // make an int16 range
|
||||
break;
|
||||
|
|
|
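The "max absolute" branch touched above rescales components so the largest magnitude lands at the int16 limit, which is useful when quantizing embeddings. A standalone sketch of that arithmetic, using made-up input values (illustration only):

#include <cmath>
#include <cstdio>

// Illustration of the embd_norm == 0 ("max absolute") arithmetic above:
// dividing by max_abs / 32760 maps the largest |component| to +/-32760.
int main() {
    const float inp[4] = { 0.5f, -2.0f, 1.0f, 0.25f };
    float out[4];
    double sum = 0.0;
    for (int i = 0; i < 4; i++) {
        if (sum < std::abs(inp[i])) {
            sum = std::abs(inp[i]); // sum ends up as max |component| = 2.0
        }
    }
    sum /= 32760.0; // make an int16 range, as in the hunk above
    for (int i = 0; i < 4; i++) {
        out[i] = (float) (inp[i] / sum);
        std::printf("%8.1f\n", out[i]); // prints 8190, -32760, 16380, 4095
    }
    return 0;
}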
@@ -596,7 +596,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si
 // Embedding utils
 //
 
-void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
+// TODO: replace embd_norm with an enum
+void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
 
 float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
 
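Dropping the `= 2` default is what forces the explicit trailing `2` at the call sites in the next two hunks. As for the TODO, one hypothetical shape the enum could take (these names are invented here and do not exist in the codebase):

// Hypothetical enum for the TODO above; the values mirror the integers
// common_embd_normalize currently dispatches on.
enum common_embd_norm : int {
    COMMON_EMBD_NORM_NONE    = -1, // return the embedding unnormalized
    COMMON_EMBD_NORM_MAX_ABS =  0, // scale into an int16 range
    COMMON_EMBD_NORM_L2      =  2, // Euclidean norm (the old default)
};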
@@ -75,7 +75,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         }
 
         std::vector<float> emb_norm(emb_unorm.size());
-        common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
+        common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2);
         result.push_back(emb_norm);
 
 #ifdef GRIT_DEBUG
@@ -107,7 +107,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         }
 
         float * out = output + batch.seq_id[i][0] * n_embd;
-        common_embd_normalize(embd, out, n_embd);
+        common_embd_normalize(embd, out, n_embd, 2);
     }
 }
 
@@ -2059,8 +2059,14 @@ struct server_context {
                 continue;
             }
 
-            common_embd_normalize(embd, embd_res.data(), n_embd);
-            res->embedding.push_back(embd_res);
+            // normalize only when there is pooling
+            // TODO: configurable
+            if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
+                common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+                res->embedding.push_back(embd_res);
+            } else {
+                res->embedding.push_back({ embd, embd + n_embd });
+            }
         }
 
         SLT_DBG(slot, "%s", "sending embeddings\n");
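With this change, pooled embeddings still come back unit-length, but pooling type 'none' now returns one raw vector per token. A client that wants the old unit-length behavior has to renormalize on its own side. A hedged sketch, where `embeddings` and `normalize_rows` are names assumed here for the parsed response and the helper:

#include <vector>

// Declared in common.h (see the header hunk above).
void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);

// Sketch: re-apply L2 normalization to per-token embeddings client-side.
static void normalize_rows(std::vector<std::vector<float>> & embeddings) {
    for (auto & emb : embeddings) {
        std::vector<float> out(emb.size());
        common_embd_normalize(emb.data(), out.data(), (int) emb.size(), 2);
        emb = std::move(out);
    }
}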
@@ -87,6 +87,10 @@ def test_embedding_pooling_none():
     assert 'embedding' in res.body[0]
     assert len(res.body[0]['embedding']) == 3
 
+    # make sure embedding vector is not normalized
+    for x in res.body[0]['embedding']:
+        assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON
+
 
 def test_embedding_pooling_none_oai():
     global server
@@ -95,6 +99,7 @@ def test_embedding_pooling_none_oai():
     res = server.make_request("POST", "/v1/embeddings", data={
         "input": "hello hello hello",
     })
 
+    # /v1/embeddings does not support pooling type 'none'
     assert res.status_code == 400
 