From b650d4cbdf85deebeb1d5d3de35e6f58b1b16139 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 13 Feb 2024 13:52:50 +0200 Subject: [PATCH] embd : minor improvements --- examples/embedding/embedding.cpp | 25 ++++++++++--------------- llama.cpp | 12 ++++++------ llama.h | 2 +- 3 files changed, 17 insertions(+), 22 deletions(-) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 20a31a2fe..b4688cf51 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -18,16 +18,8 @@ static std::vector split_lines(const std::string & s) { } static void batch_add_seq(llama_batch & batch, const std::vector & tokens, int seq_id) { - const uint64_t n_tokens = tokens.size(); - int n_past = batch.n_tokens; - batch.n_tokens += n_tokens; - for (uint64_t i = 0; i < n_tokens; i++) { - uint64_t j = n_past + i; - batch.token[j] = tokens[i]; - batch.pos[j] = i; - batch.n_seq_id[j] = 1; - batch.seq_id[j][0] = seq_id; - batch.logits[j] = 0; + for (size_t i = 0; i < tokens.size(); i++) { + llama_batch_add(batch, tokens[i], i, { seq_id }, false); } } @@ -158,7 +150,7 @@ int main(int argc, char ** argv) { if (batch.n_tokens + n_toks > n_batch) { float * out = emb + p * n_embd; batch_decode(ctx, batch, out, s, n_embd); - batch.n_tokens = 0; + llama_batch_clear(batch); p += s; s = 0; } @@ -172,10 +164,13 @@ int main(int argc, char ** argv) { float * out = emb + p * n_embd; batch_decode(ctx, batch, out, s, n_embd); - // print first embedding - fprintf(stderr, "\nfirst embedding:\n"); - for (int i = 0; i < n_embd; i++) { - fprintf(stderr, "%f ", emb[i]); + // print first 3 embeddings + for (int j = 0; j < std::min(3, n_prompts); j++) { + fprintf(stderr, "embedding %d: ", j); + for (int i = 0; i < n_embd; i++) { + fprintf(stderr, "%f ", emb[j * n_embd + i]); + } + fprintf(stderr, "\n\n"); } fprintf(stderr, "\n"); diff --git a/llama.cpp b/llama.cpp index ec2a42736..62fdc979e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5826,7 +5826,7 @@ struct llm_build_context { if (do_pooling) { cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_sum); } - cb(cur, "result_embed", -1); + cb(cur, "result_embd", -1); ggml_build_forward_expand(gf, cur); @@ -7516,7 +7516,7 @@ static int llama_decode_internal( embeddings = gf->nodes[gf->n_nodes - 3]; GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); } - } else if (strcmp(res->name, "result_embed") == 0) { + } else if (strcmp(res->name, "result_embd") == 0) { embeddings = res; res = nullptr; } else { @@ -7636,12 +7636,12 @@ static int llama_decode_internal( if (!lctx.embedding.empty()) { auto & embedding_out = lctx.embedding; - const int64_t embed_pos = res ? n_embd * (n_tokens-1) : 0; - const int64_t embed_size = res ? n_embd : n_embd * n_tokens; + const int64_t embd_pos = res ? n_embd * (n_tokens-1) : 0; + const int64_t embd_size = res ? n_embd : n_embd * n_tokens; - embedding_out.resize(embed_size); + embedding_out.resize(embd_size); ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings); - ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), embed_size*sizeof(float)); + ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embd_pos*sizeof(float), embd_size*sizeof(float)); ggml_backend_synchronize(embeddings_backend); } diff --git a/llama.h b/llama.h index 0d4bae798..5ef78ec96 100644 --- a/llama.h +++ b/llama.h @@ -629,7 +629,7 @@ extern "C" { // shape: [n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); - // Get the embeddings for the ith token + // Get the embeddings for the ith sequence // llama_get_embeddings(ctx) + i*n_embd LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);