embd : minor improvements

This commit is contained in:
Georgi Gerganov 2024-02-13 13:52:50 +02:00
parent f281d76f41
commit b650d4cbdf
No known key found for this signature in database
GPG key ID: BF970631944C16B7
3 changed files with 17 additions and 22 deletions

View file

@ -18,16 +18,8 @@ static std::vector<std::string> split_lines(const std::string & s) {
} }
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) { static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
const uint64_t n_tokens = tokens.size(); for (size_t i = 0; i < tokens.size(); i++) {
int n_past = batch.n_tokens; llama_batch_add(batch, tokens[i], i, { seq_id }, false);
batch.n_tokens += n_tokens;
for (uint64_t i = 0; i < n_tokens; i++) {
uint64_t j = n_past + i;
batch.token[j] = tokens[i];
batch.pos[j] = i;
batch.n_seq_id[j] = 1;
batch.seq_id[j][0] = seq_id;
batch.logits[j] = 0;
} }
} }
@ -158,7 +150,7 @@ int main(int argc, char ** argv) {
if (batch.n_tokens + n_toks > n_batch) { if (batch.n_tokens + n_toks > n_batch) {
float * out = emb + p * n_embd; float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd); batch_decode(ctx, batch, out, s, n_embd);
batch.n_tokens = 0; llama_batch_clear(batch);
p += s; p += s;
s = 0; s = 0;
} }
@ -172,10 +164,13 @@ int main(int argc, char ** argv) {
float * out = emb + p * n_embd; float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd); batch_decode(ctx, batch, out, s, n_embd);
// print first embedding // print first 3 embeddings
fprintf(stderr, "\nfirst embedding:\n"); for (int j = 0; j < std::min(3, n_prompts); j++) {
for (int i = 0; i < n_embd; i++) { fprintf(stderr, "embedding %d: ", j);
fprintf(stderr, "%f ", emb[i]); for (int i = 0; i < n_embd; i++) {
fprintf(stderr, "%f ", emb[j * n_embd + i]);
}
fprintf(stderr, "\n\n");
} }
fprintf(stderr, "\n"); fprintf(stderr, "\n");

View file

@ -5826,7 +5826,7 @@ struct llm_build_context {
if (do_pooling) { if (do_pooling) {
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_sum); cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_sum);
} }
cb(cur, "result_embed", -1); cb(cur, "result_embd", -1);
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
@ -7516,7 +7516,7 @@ static int llama_decode_internal(
embeddings = gf->nodes[gf->n_nodes - 3]; embeddings = gf->nodes[gf->n_nodes - 3];
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
} }
} else if (strcmp(res->name, "result_embed") == 0) { } else if (strcmp(res->name, "result_embd") == 0) {
embeddings = res; embeddings = res;
res = nullptr; res = nullptr;
} else { } else {
@ -7636,12 +7636,12 @@ static int llama_decode_internal(
if (!lctx.embedding.empty()) { if (!lctx.embedding.empty()) {
auto & embedding_out = lctx.embedding; auto & embedding_out = lctx.embedding;
const int64_t embed_pos = res ? n_embd * (n_tokens-1) : 0; const int64_t embd_pos = res ? n_embd * (n_tokens-1) : 0;
const int64_t embed_size = res ? n_embd : n_embd * n_tokens; const int64_t embd_size = res ? n_embd : n_embd * n_tokens;
embedding_out.resize(embed_size); embedding_out.resize(embd_size);
ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings); ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), embed_size*sizeof(float)); ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embd_pos*sizeof(float), embd_size*sizeof(float));
ggml_backend_synchronize(embeddings_backend); ggml_backend_synchronize(embeddings_backend);
} }

View file

@ -629,7 +629,7 @@ extern "C" {
// shape: [n_embd] (1-dimensional) // shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Get the embeddings for the ith token // Get the embeddings for the ith sequence
// llama_get_embeddings(ctx) + i*n_embd // llama_get_embeddings(ctx) + i*n_embd
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);