rebase to new embed

This commit is contained in:
Douglas Hanley 2024-03-05 23:23:17 -06:00
parent 805ae529c4
commit 97936078b7
3 changed files with 18 additions and 20 deletions

View file

@ -39,24 +39,23 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
// testing with and without EOS - unexpected embeddings in both cases - GritLM seems to have EOS = ""
std::string input_string = instruction + sentences[i];
auto inputs = llama_tokenize(mdl, input_string, true, false);
uint64_t n_toks = inputs.size();
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L116
// inputs.push_back(llama_token_eos(mdl));
// we want to ignore instruction tokens for mean pooling
auto inputs_instruct = llama_tokenize(mdl, instruction, true, false);
int n_inst = inputs_instruct.size();
uint64_t n_inst = inputs_instruct.size();
/*/
// debug tokens - these are matching as referenced in their sample so doesn't appear to be a token issue
std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) {
std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str());
});
std::printf("\n");
*/
// add input to batch (this increments n_tokens)
for (uint64_t j = 0; j < inputs.size(); j++) {
llama_batch_add(batch, inputs[j], j, { 0 }, false);
for (uint64_t j = 0; j < n_toks; j++) {
llama_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
}
// clear previous kv_cache values (irrelevant for embeddings)
@ -66,23 +65,22 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
llama_decode(ctx, batch);
// get embedding dimensions
int n_toks = inputs.size();
int n_embd = llama_n_embd(mdl);
uint64_t n_embd = llama_n_embd(mdl);
// allocate embedding output
std::vector<float> emb_unorm(n_embd, 0.0f);
// sum up all token embeddings
for (int k = n_inst; k < n_toks; k++) {
for (uint64_t k = n_inst; k < n_toks; k++) {
float * emb = llama_get_embeddings_ith(ctx, k);
for (int j = 0; j < n_embd; j++) {
for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] += emb[j];
}
}
// divide by number of tokens (mean pooling)
int n_sent = n_toks - n_inst;
for (int j = 0; j < n_embd; j++) {
uint64_t n_sent = n_toks - n_inst;
for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] /= n_sent;
}
@ -90,14 +88,12 @@ static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vec
normalize(emb_unorm, emb_norm.data());
result.push_back(emb_norm);
/*
// print out emb_norm
std::printf("embedding %ld: ", i);
for (int j = 0; j < n_embd; j++) {
for (uint64_t j = 0; j < 20; j++) {
std::printf("%.5f ", emb_norm[j]);
}
std::printf("\n");
*/
std::printf("\n\n");
llama_batch_free(batch);
}
@ -124,14 +120,14 @@ int main(int argc, char* argv[])
);
return true;
};
cparams.embedding = true;
cparams.embeddings = true;
cparams.causal_attn = false;
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
llama_backend_init();
auto mdl = llama_load_model_from_file(params.model.c_str(), mparams);
auto ctx = llama_new_context_with_model(mdl, cparams);
auto bat = llama_batch_init(llama_n_ctx(ctx), 0, 1);
// ### Embedding/Representation ### taken sample from here:
// https://github.com/ContextualAI/gritlm?tab=readme-ov-file#basic
@ -167,7 +163,6 @@ int main(int argc, char* argv[])
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
}
llama_batch_free(bat);
llama_free(ctx);
llama_free_model(mdl);
llama_backend_free();