embd : minor improvements
This commit is contained in:
parent
f281d76f41
commit
b650d4cbdf
3 changed files with 17 additions and 22 deletions
|
@ -18,16 +18,8 @@ static std::vector<std::string> split_lines(const std::string & s) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
|
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
|
||||||
const uint64_t n_tokens = tokens.size();
|
for (size_t i = 0; i < tokens.size(); i++) {
|
||||||
int n_past = batch.n_tokens;
|
llama_batch_add(batch, tokens[i], i, { seq_id }, false);
|
||||||
batch.n_tokens += n_tokens;
|
|
||||||
for (uint64_t i = 0; i < n_tokens; i++) {
|
|
||||||
uint64_t j = n_past + i;
|
|
||||||
batch.token[j] = tokens[i];
|
|
||||||
batch.pos[j] = i;
|
|
||||||
batch.n_seq_id[j] = 1;
|
|
||||||
batch.seq_id[j][0] = seq_id;
|
|
||||||
batch.logits[j] = 0;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,7 +150,7 @@ int main(int argc, char ** argv) {
|
||||||
if (batch.n_tokens + n_toks > n_batch) {
|
if (batch.n_tokens + n_toks > n_batch) {
|
||||||
float * out = emb + p * n_embd;
|
float * out = emb + p * n_embd;
|
||||||
batch_decode(ctx, batch, out, s, n_embd);
|
batch_decode(ctx, batch, out, s, n_embd);
|
||||||
batch.n_tokens = 0;
|
llama_batch_clear(batch);
|
||||||
p += s;
|
p += s;
|
||||||
s = 0;
|
s = 0;
|
||||||
}
|
}
|
||||||
|
@ -172,10 +164,13 @@ int main(int argc, char ** argv) {
|
||||||
float * out = emb + p * n_embd;
|
float * out = emb + p * n_embd;
|
||||||
batch_decode(ctx, batch, out, s, n_embd);
|
batch_decode(ctx, batch, out, s, n_embd);
|
||||||
|
|
||||||
// print first embedding
|
// print first 3 embeddings
|
||||||
fprintf(stderr, "\nfirst embedding:\n");
|
for (int j = 0; j < std::min(3, n_prompts); j++) {
|
||||||
|
fprintf(stderr, "embedding %d: ", j);
|
||||||
for (int i = 0; i < n_embd; i++) {
|
for (int i = 0; i < n_embd; i++) {
|
||||||
fprintf(stderr, "%f ", emb[i]);
|
fprintf(stderr, "%f ", emb[j * n_embd + i]);
|
||||||
|
}
|
||||||
|
fprintf(stderr, "\n\n");
|
||||||
}
|
}
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
|
|
||||||
|
|
12
llama.cpp
12
llama.cpp
|
@ -5826,7 +5826,7 @@ struct llm_build_context {
|
||||||
if (do_pooling) {
|
if (do_pooling) {
|
||||||
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_sum);
|
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_sum);
|
||||||
}
|
}
|
||||||
cb(cur, "result_embed", -1);
|
cb(cur, "result_embd", -1);
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, cur);
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
|
@ -7516,7 +7516,7 @@ static int llama_decode_internal(
|
||||||
embeddings = gf->nodes[gf->n_nodes - 3];
|
embeddings = gf->nodes[gf->n_nodes - 3];
|
||||||
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
|
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
|
||||||
}
|
}
|
||||||
} else if (strcmp(res->name, "result_embed") == 0) {
|
} else if (strcmp(res->name, "result_embd") == 0) {
|
||||||
embeddings = res;
|
embeddings = res;
|
||||||
res = nullptr;
|
res = nullptr;
|
||||||
} else {
|
} else {
|
||||||
|
@ -7636,12 +7636,12 @@ static int llama_decode_internal(
|
||||||
if (!lctx.embedding.empty()) {
|
if (!lctx.embedding.empty()) {
|
||||||
auto & embedding_out = lctx.embedding;
|
auto & embedding_out = lctx.embedding;
|
||||||
|
|
||||||
const int64_t embed_pos = res ? n_embd * (n_tokens-1) : 0;
|
const int64_t embd_pos = res ? n_embd * (n_tokens-1) : 0;
|
||||||
const int64_t embed_size = res ? n_embd : n_embd * n_tokens;
|
const int64_t embd_size = res ? n_embd : n_embd * n_tokens;
|
||||||
|
|
||||||
embedding_out.resize(embed_size);
|
embedding_out.resize(embd_size);
|
||||||
ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
|
ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
|
||||||
ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), embed_size*sizeof(float));
|
ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embd_pos*sizeof(float), embd_size*sizeof(float));
|
||||||
ggml_backend_synchronize(embeddings_backend);
|
ggml_backend_synchronize(embeddings_backend);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
2
llama.h
2
llama.h
|
@ -629,7 +629,7 @@ extern "C" {
|
||||||
// shape: [n_embd] (1-dimensional)
|
// shape: [n_embd] (1-dimensional)
|
||||||
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||||
|
|
||||||
// Get the embeddings for the ith token
|
// Get the embeddings for the ith sequence
|
||||||
// llama_get_embeddings(ctx) + i*n_embd
|
// llama_get_embeddings(ctx) + i*n_embd
|
||||||
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue