rebase to new embed
This commit is contained in:
parent
805ae529c4
commit
97936078b7
3 changed files with 18 additions and 20 deletions
|
@ -39,24 +39,23 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
|
|||
// testing with and without EOS - unexpected embeddings in both cases - GritLM seems to have EOS = ""
|
||||
std::string input_string = instruction + sentences[i];
|
||||
auto inputs = llama_tokenize(mdl, input_string, true, false);
|
||||
uint64_t n_toks = inputs.size();
|
||||
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L116
|
||||
// inputs.push_back(llama_token_eos(mdl));
|
||||
|
||||
// we want to ignore instruction tokens for mean pooling
|
||||
auto inputs_instruct = llama_tokenize(mdl, instruction, true, false);
|
||||
int n_inst = inputs_instruct.size();
|
||||
uint64_t n_inst = inputs_instruct.size();
|
||||
|
||||
/*/
|
||||
// debug tokens - these are matching as referenced in their sample so doesn't appear to be a token issue
|
||||
std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) {
|
||||
std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str());
|
||||
});
|
||||
std::printf("\n");
|
||||
*/
|
||||
|
||||
// add input to batch (this increments n_tokens)
|
||||
for (uint64_t j = 0; j < inputs.size(); j++) {
|
||||
llama_batch_add(batch, inputs[j], j, { 0 }, false);
|
||||
for (uint64_t j = 0; j < n_toks; j++) {
|
||||
llama_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
|
||||
}
|
||||
|
||||
// clear previous kv_cache values (irrelevant for embeddings)
|
||||
|
@ -66,23 +65,22 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
|
|||
llama_decode(ctx, batch);
|
||||
|
||||
// get embedding dimensions
|
||||
int n_toks = inputs.size();
|
||||
int n_embd = llama_n_embd(mdl);
|
||||
uint64_t n_embd = llama_n_embd(mdl);
|
||||
|
||||
// allocate embedding output
|
||||
std::vector<float> emb_unorm(n_embd, 0.0f);
|
||||
|
||||
// sum up all token embeddings
|
||||
for (int k = n_inst; k < n_toks; k++) {
|
||||
for (uint64_t k = n_inst; k < n_toks; k++) {
|
||||
float * emb = llama_get_embeddings_ith(ctx, k);
|
||||
for (int j = 0; j < n_embd; j++) {
|
||||
for (uint64_t j = 0; j < n_embd; j++) {
|
||||
emb_unorm[j] += emb[j];
|
||||
}
|
||||
}
|
||||
|
||||
// divide by number of tokens (mean pooling)
|
||||
int n_sent = n_toks - n_inst;
|
||||
for (int j = 0; j < n_embd; j++) {
|
||||
uint64_t n_sent = n_toks - n_inst;
|
||||
for (uint64_t j = 0; j < n_embd; j++) {
|
||||
emb_unorm[j] /= n_sent;
|
||||
}
|
||||
|
||||
|
@ -90,14 +88,12 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
|
|||
normalize(emb_unorm, emb_norm.data());
|
||||
result.push_back(emb_norm);
|
||||
|
||||
/*
|
||||
// print out emb_norm
|
||||
std::printf("embedding %ld: ", i);
|
||||
for (int j = 0; j < n_embd; j++) {
|
||||
for (uint64_t j = 0; j < 20; j++) {
|
||||
std::printf("%.5f ", emb_norm[j]);
|
||||
}
|
||||
std::printf("\n");
|
||||
*/
|
||||
std::printf("\n\n");
|
||||
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
|
@ -124,14 +120,14 @@ int main(int argc, char* argv[])
|
|||
);
|
||||
return true;
|
||||
};
|
||||
cparams.embedding = true;
|
||||
cparams.embeddings = true;
|
||||
cparams.causal_attn = false;
|
||||
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
llama_backend_init();
|
||||
|
||||
auto mdl = llama_load_model_from_file(params.model.c_str(), mparams);
|
||||
auto ctx = llama_new_context_with_model(mdl, cparams);
|
||||
auto bat = llama_batch_init(llama_n_ctx(ctx), 0, 1);
|
||||
|
||||
// ### Embedding/Representation ### taken sample from here:
|
||||
// https://github.com/ContextualAI/gritlm?tab=readme-ov-file#basic
|
||||
|
@ -167,7 +163,6 @@ int main(int argc, char* argv[])
|
|||
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
|
||||
}
|
||||
|
||||
llama_batch_free(bat);
|
||||
llama_free(ctx);
|
||||
llama_free_model(mdl);
|
||||
llama_backend_free();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue