embd : minor improvements

This commit is contained in:
Georgi Gerganov 2024-02-13 13:52:50 +02:00
parent f281d76f41
commit b650d4cbdf
No known key found for this signature in database
GPG key ID: BF970631944C16B7
3 changed files with 17 additions and 22 deletions

View file

@@ -18,16 +18,8 @@ static std::vector<std::string> split_lines(const std::string & s) {
}
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
const uint64_t n_tokens = tokens.size();
int n_past = batch.n_tokens;
batch.n_tokens += n_tokens;
for (uint64_t i = 0; i < n_tokens; i++) {
uint64_t j = n_past + i;
batch.token[j] = tokens[i];
batch.pos[j] = i;
batch.n_seq_id[j] = 1;
batch.seq_id[j][0] = seq_id;
batch.logits[j] = 0;
for (size_t i = 0; i < tokens.size(); i++) {
llama_batch_add(batch, tokens[i], i, { seq_id }, false);
}
}
@@ -158,7 +150,7 @@ int main(int argc, char ** argv) {
if (batch.n_tokens + n_toks > n_batch) {
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);
batch.n_tokens = 0;
llama_batch_clear(batch);
p += s;
s = 0;
}
@@ -172,10 +164,13 @@ int main(int argc, char ** argv) {
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);
// print first embedding
fprintf(stderr, "\nfirst embedding:\n");
// print first 3 embeddings
for (int j = 0; j < std::min(3, n_prompts); j++) {
fprintf(stderr, "embedding %d: ", j);
for (int i = 0; i < n_embd; i++) {
fprintf(stderr, "%f ", emb[i]);
fprintf(stderr, "%f ", emb[j * n_embd + i]);
}
fprintf(stderr, "\n\n");
}
fprintf(stderr, "\n");

View file

@@ -5826,7 +5826,7 @@ struct llm_build_context {
if (do_pooling) {
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_sum);
}
cb(cur, "result_embed", -1);
cb(cur, "result_embd", -1);
ggml_build_forward_expand(gf, cur);
@@ -7516,7 +7516,7 @@ static int llama_decode_internal(
embeddings = gf->nodes[gf->n_nodes - 3];
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
}
} else if (strcmp(res->name, "result_embed") == 0) {
} else if (strcmp(res->name, "result_embd") == 0) {
embeddings = res;
res = nullptr;
} else {
@@ -7636,12 +7636,12 @@ static int llama_decode_internal(
if (!lctx.embedding.empty()) {
auto & embedding_out = lctx.embedding;
const int64_t embed_pos = res ? n_embd * (n_tokens-1) : 0;
const int64_t embed_size = res ? n_embd : n_embd * n_tokens;
const int64_t embd_pos = res ? n_embd * (n_tokens-1) : 0;
const int64_t embd_size = res ? n_embd : n_embd * n_tokens;
embedding_out.resize(embed_size);
embedding_out.resize(embd_size);
ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), embed_size*sizeof(float));
ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embd_pos*sizeof(float), embd_size*sizeof(float));
ggml_backend_synchronize(embeddings_backend);
}

View file

@ -629,7 +629,7 @@ extern "C" {
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Get the embeddings for the ith token
// Get the embeddings for the ith sequence
// llama_get_embeddings(ctx) + i*n_embd
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);