common : better normalize impl

This commit is contained in:
Georgi Gerganov 2024-03-09 14:26:28 +02:00
parent 98cccf14e3
commit 02addabbae
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -1854,15 +1854,14 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
} }
void llama_embd_normalize(const float * inp, float * out, int n) { void llama_embd_normalize(const float * inp, float * out, int n) {
float norm = 0; double sum = 0.0;
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
norm += inp[i] * inp[i]; sum += inp[i] * inp[i];
} }
norm = sqrt(norm); sum = sqrt(sum);
if (norm == 0) {
return; const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
}
norm = 1.0 / norm;
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
out[i] = inp[i] * norm; out[i] = inp[i] * norm;
} }