common : reuse llama_embd_normalize

This commit is contained in:
Georgi Gerganov 2024-03-09 14:23:50 +02:00
parent 08d2ea1edb
commit 98cccf14e3
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 29 additions and 30 deletions

View file

@ -1852,3 +1852,19 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
printf("\n=== Done dumping\n"); printf("\n=== Done dumping\n");
} }
// Normalize the n-element embedding vector `inp` to unit (L2) length,
// writing the result to `out`. `inp` and `out` may alias.
//
// If the input has zero norm (all elements zero, or n <= 0), `out` is
// deliberately left unmodified to avoid a division by zero — callers are
// expected to have initialized the destination buffer themselves.
void llama_embd_normalize(const float * inp, float * out, int n) {
    // accumulate the squared magnitudes in double: summing float*float in a
    // float accumulator loses precision and can overflow for large n or
    // large-magnitude embeddings
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        sum += (double) inp[i] * inp[i];
    }

    const double norm = sqrt(sum);
    if (norm == 0.0) {
        return; // degenerate input: leave `out` untouched
    }

    // multiply by the reciprocal instead of dividing per element
    const float scale = (float) (1.0 / norm);
    for (int i = 0; i < n; i++) {
        out[i] = inp[i] * scale;
    }
}

View file

@ -260,3 +260,10 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
// Dump the KV cache view showing individual sequences in each cell (long output). // Dump the KV cache view showing individual sequences in each cell (long output).
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
//
// Embedding utils
//
void llama_embd_normalize(const float * inp, float * out, int n);

View file

@ -23,17 +23,6 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
} }
} }
// Normalize the n-element vector `vec` to unit (L2) length into `out`.
// `vec` and `out` may alias.
//
// Fix: the original divided unconditionally, so an all-zero input produced
// a division by zero and filled `out` with NaN/inf. A zero-norm input now
// leaves `out` unmodified instead.
static void normalize(const float * vec, float * out, int n) {
    float norm = 0;
    for (int i = 0; i < n; i++) {
        norm += vec[i] * vec[i];
    }
    norm = sqrt(norm);
    if (norm == 0) {
        return; // avoid division by zero on a zero-norm (or empty) input
    }
    for (int i = 0; i < n; i++) {
        out[i] = vec[i] / norm;
    }
}
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
// clear previous kv_cache values (irrelevant for embeddings) // clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx); llama_kv_cache_clear(ctx);
@ -44,7 +33,6 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
fprintf(stderr, "%s : failed to decode\n", __func__); fprintf(stderr, "%s : failed to decode\n", __func__);
} }
// normalize on copy
for (int i = 0; i < batch.n_tokens; i++) { for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) { if (!batch.logits[i]) {
continue; continue;
@ -61,7 +49,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
} }
float * out = output + batch.seq_id[i][0] * n_embd; float * out = output + batch.seq_id[i][0] * n_embd;
normalize(embd, out, n_embd); llama_embd_normalize(embd, out, n_embd);
} }
} }

View file

@ -1318,6 +1318,8 @@ struct server_context {
const int n_embd = llama_n_embd(model); const int n_embd = llama_n_embd(model);
std::vector<float> embd_res(n_embd, 0.0f);
for (int i = 0; i < batch.n_tokens; ++i) { for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) { if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
continue; continue;
@ -1341,8 +1343,10 @@ struct server_context {
continue; continue;
} }
llama_embd_normalize(embd, embd_res.data(), n_embd);
res.data = json { res.data = json {
{"embedding", std::vector<float>(embd, embd + n_embd)}, {"embedding", embd_res},
}; };
} }
@ -2651,17 +2655,6 @@ inline void signal_handler(int signal) {
shutdown_handler(signal); shutdown_handler(signal);
} }
// Normalize `vec` in place to unit (L2) length.
//
// Fix: the original divided every element unconditionally, so a zero-norm
// input (all zeros, or an empty vector producing norm == 0) turned the
// contents into NaNs. Such inputs are now left unchanged.
static void normalize(std::vector<float>& vec) {
    float norm = 0;
    for (float val : vec) {
        norm += val * val;
    }
    norm = sqrt(norm);
    if (norm == 0) {
        return; // avoid division by zero on a zero-norm or empty vector
    }
    for (float& val : vec) {
        val /= norm;
    }
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
#if SERVER_VERBOSE != 1 #if SERVER_VERBOSE != 1
log_disable(); log_disable();
@ -3357,11 +3350,6 @@ int main(int argc, char ** argv) {
server_task_result result = ctx_server.queue_results.recv(id_task); server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task);
// normalize the embedding
std::vector<float> embedding = json_value(result.data, "embedding", json::array());
normalize(embedding);
result.data["embedding"] = embedding;
// append to the responses // append to the responses
responses.push_back(result.data); responses.push_back(result.data);
} }