llama : fix embeddings

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-02-29 15:39:10 +02:00
parent a0fc62661f
commit d0347840c1
No known key found for this signature in database
GPG key ID: BF970631944C16B7
6 changed files with 127 additions and 62 deletions

View file

@ -19,7 +19,7 @@ static std::vector<std::string> split_lines(const std::string & s) {
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
for (size_t i = 0; i < tokens.size(); i++) {
llama_batch_add(batch, tokens[i], i, { seq_id }, false);
llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
}
}
@ -45,9 +45,13 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}
// normalize on copy
for (int k = 0; k < n_seq; k++) {
float * emb = llama_get_embeddings_ith(ctx, k);
float * out = output + k * n_embd;
for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
continue;
}
float * emb = llama_get_embeddings_ith(ctx, i);
float * out = output + batch.seq_id[i][0] * n_embd;
normalize(emb, out, n_embd);
}
}
@ -145,6 +149,7 @@ int main(int argc, char ** argv) {
for (int k = 0; k < n_prompts; k++) {
// clamp to n_batch tokens
auto & inp = inputs[k];
const uint64_t n_toks = inp.size();
// encode if at capacity