llama : fix embeddings
ggml-ci
parent a0fc62661f
commit d0347840c1
6 changed files with 127 additions and 62 deletions
examples/embedding/embedding.cpp

@@ -19,7 +19,7 @@ static std::vector<std::string> split_lines(const std::string & s) {
 static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
     for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, false);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
     }
 }
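The change above is the core of the fix: previously every token passed false for the logits flag, so no output row was ever requested; now the last token of each sequence requests one, which is the row the copy loop below reads back. A minimal usage sketch, assuming the llama_batch_init/llama_batch_add helpers from llama.h and common.h; toks0 and toks1 are hypothetical tokenized prompts:

    // sketch: two sequences share one batch; only their last tokens
    // request an output row (logits flag = true)
    llama_batch batch = llama_batch_init(512, 0, 1);

    batch_add_seq(batch, toks0, 0); // toks0: hypothetical std::vector<int32_t>
    batch_add_seq(batch, toks1, 1); // toks1: hypothetical std::vector<int32_t>

    // after llama_decode(ctx, batch), one embedding row exists per
    // sequence, at the positions where batch.logits[i] was set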
@@ -45,9 +45,13 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }
 
     // normalize on copy
-    for (int k = 0; k < n_seq; k++) {
-        float * emb = llama_get_embeddings_ith(ctx, k);
-        float * out = output + k * n_embd;
+    for (int i = 0; i < batch.n_tokens; i++) {
+        if (!batch.logits[i]) {
+            continue;
+        }
+
+        float * emb = llama_get_embeddings_ith(ctx, i);
+        float * out = output + batch.seq_id[i][0] * n_embd;
         normalize(emb, out, n_embd);
     }
 }
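The rewritten loop no longer assumes output row k belongs to sequence k: it walks every batch token, skips tokens that did not request an output, and scatters each embedding into the output slot named by the token's first sequence id. The normalize helper is defined elsewhere in embedding.cpp and not shown in this excerpt; a plausible L2-normalization sketch consistent with its call site:

    #include <cmath>

    // assumed reconstruction of normalize(): copy the n-dimensional
    // vector into out, scaled to unit L2 norm
    static void normalize(const float * vec, float * out, int n) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            sum += (double) vec[i] * vec[i];
        }
        const float scale = sum > 0.0 ? (float) (1.0 / std::sqrt(sum)) : 0.0f;
        for (int i = 0; i < n; i++) {
            out[i] = vec[i] * scale;
        }
    }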
@@ -145,6 +149,7 @@ int main(int argc, char ** argv) {
     for (int k = 0; k < n_prompts; k++) {
         // clamp to n_batch tokens
         auto & inp = inputs[k];
 
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
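For context, the loop this hunk opens implements a fill-and-flush pattern: prompts are appended to one shared batch, and the batch is decoded whenever the next prompt would overflow n_batch. A sketch of that pattern, reconstructed under assumptions rather than quoted verbatim from the file (emb, inputs, n_batch, and n_embd as set up earlier in main):

    int p = 0; // number of prompts processed already
    int s = 0; // number of prompts in the current batch
    for (int k = 0; k < n_prompts; k++) {
        auto & inp = inputs[k];

        const uint64_t n_toks = inp.size();

        // encode if at capacity: flush pending sequences first
        if (batch.n_tokens + n_toks > (uint64_t) n_batch) {
            batch_decode(ctx, batch, emb + p * n_embd, s, n_embd);
            llama_batch_clear(batch);
            p += s;
            s = 0;
        }

        batch_add_seq(batch, inp, s);
        s += 1;
    }

    // flush the final partial batch
    batch_decode(ctx, batch, emb + p * n_embd, s, n_embd);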