server : normalize embeddings (#5956)

* output normalized embeddings in '/v1/embeddings'

* common : reuse llama_embd_normalize

* common : better normalize impl

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
author SeungWon Jeong, 2024-03-09 21:27:58 +09:00 (committed by GitHub)
parent 2c4f566c88
commit fb215c3832
4 changed files with 30 additions and 14 deletions
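
The change below drops the example-local normalize() and calls a shared llama_embd_normalize() helper from common instead, so the server can reuse the same normalization when returning vectors from '/v1/embeddings'. A minimal sketch of what the shared helper plausibly looks like, in the spirit of the "better normalize impl" message above; the name and call signature match the call site in the diff, but the body details (double accumulator, zero-norm guard) are assumptions, not taken from the hunks shown here:

#include <math.h>

// Sketch of the shared helper added to common (body is assumed, not quoted
// from the commit): accumulate the squared sum in double for precision and
// guard against a zero vector so we never divide by zero.
void llama_embd_normalize(const float * inp, float * out, int n) {
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        sum += (double) inp[i] * (double) inp[i];
    }
    sum = sqrt(sum);
    const float norm = sum > 0.0 ? (float) (1.0 / sum) : 0.0f;
    for (int i = 0; i < n; i++) {
        out[i] = inp[i] * norm; // copy with unit L2 norm
    }
}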

examples/embedding/embedding.cpp

@@ -23,17 +23,6 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
     }
 }
 
-static void normalize(const float * vec, float * out, int n) {
-    float norm = 0;
-    for (int i = 0; i < n; i++) {
-        norm += vec[i] * vec[i];
-    }
-    norm = sqrt(norm);
-    for (int i = 0; i < n; i++) {
-        out[i] = vec[i] / norm;
-    }
-}
-
 static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_cache_clear(ctx);
@@ -44,7 +33,6 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         fprintf(stderr, "%s : failed to decode\n", __func__);
     }
 
-    // normalize on copy
     for (int i = 0; i < batch.n_tokens; i++) {
         if (!batch.logits[i]) {
             continue;
@@ -61,7 +49,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         }
 
         float * out = output + batch.seq_id[i][0] * n_embd;
-        normalize(embd, out, n_embd);
+        llama_embd_normalize(embd, out, n_embd);
     }
 }
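
On the server side (that hunk is not shown above), the same helper can be applied before the embedding is serialized into the OpenAI-compatible '/v1/embeddings' response. A hedged usage sketch; raw_embd and n_embd are placeholder names, not identifiers quoted from server.cpp:

// Normalize on copy into a buffer that is then written out as the
// "embedding" array of the '/v1/embeddings' JSON response.
std::vector<float> embedding(n_embd);
llama_embd_normalize(raw_embd, embedding.data(), n_embd);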