Embeddings extraction support
This commit is contained in:
parent
859e70899a
commit
8a3c34bb54
3 changed files with 59 additions and 56 deletions
92
llama.cpp
92
llama.cpp
|
@ -101,9 +101,10 @@ struct llama_context {
|
||||||
|
|
||||||
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
||||||
std::vector<float> logits;
|
std::vector<float> logits;
|
||||||
|
bool logits_all = false;
|
||||||
|
|
||||||
// input embedding (1-dimensional array: [n_embd])
|
// input embedding (1-dimensional array: [n_embd])
|
||||||
std::vector<float> embedding;
|
std::vector<float> embedding;
|
||||||
bool logits_all = false;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_context_params llama_context_default_params() {
|
struct llama_context_params llama_context_default_params() {
|
||||||
|
@ -114,7 +115,7 @@ struct llama_context_params llama_context_default_params() {
|
||||||
/*.f16_kv =*/ false,
|
/*.f16_kv =*/ false,
|
||||||
/*.logits_all =*/ false,
|
/*.logits_all =*/ false,
|
||||||
/*.vocab_only =*/ false,
|
/*.vocab_only =*/ false,
|
||||||
/*.embedding =*/ false,
|
/*.embedding =*/ false,
|
||||||
};
|
};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -130,8 +131,7 @@ static bool llama_model_load(
|
||||||
int n_ctx,
|
int n_ctx,
|
||||||
int n_parts,
|
int n_parts,
|
||||||
ggml_type memory_type,
|
ggml_type memory_type,
|
||||||
bool vocab_only,
|
bool vocab_only) {
|
||||||
bool embedding) {
|
|
||||||
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
||||||
|
|
||||||
const int64_t t_start_us = ggml_time_us();
|
const int64_t t_start_us = ggml_time_us();
|
||||||
|
@ -596,29 +596,11 @@ static bool llama_model_load(
|
||||||
fin.close();
|
fin.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
lctx.logits.reserve(lctx.model.hparams.n_ctx);
|
|
||||||
|
|
||||||
if (embedding){
|
|
||||||
lctx.embedding.reserve(lctx.model.hparams.n_embd);
|
|
||||||
}
|
|
||||||
|
|
||||||
lctx.t_load_us = ggml_time_us() - t_start_us;
|
lctx.t_load_us = ggml_time_us() - t_start_us;
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prints the provided embedding vector to stdout
|
|
||||||
// in a neat format
|
|
||||||
void display_embedding(const std::vector<float> & embedding_representation){
|
|
||||||
fprintf(stdout, "\n[\n");
|
|
||||||
for (int j = 0; j < embedding_representation.size()-1 ; j++){
|
|
||||||
fprintf(stdout, "%f, ", embedding_representation[j]);
|
|
||||||
}
|
|
||||||
fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]);
|
|
||||||
fprintf(stdout, "\n]\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// evaluate the transformer
|
// evaluate the transformer
|
||||||
//
|
//
|
||||||
// - lctx: llama context
|
// - lctx: llama context
|
||||||
|
@ -631,8 +613,7 @@ static bool llama_eval_internal(
|
||||||
const llama_token * tokens,
|
const llama_token * tokens,
|
||||||
const int n_tokens,
|
const int n_tokens,
|
||||||
const int n_past,
|
const int n_past,
|
||||||
const int n_threads,
|
const int n_threads) {
|
||||||
const bool embedding_mode = false) {
|
|
||||||
const int64_t t_start_us = ggml_time_us();
|
const int64_t t_start_us = ggml_time_us();
|
||||||
|
|
||||||
const int N = n_tokens;
|
const int N = n_tokens;
|
||||||
|
@ -810,6 +791,9 @@ static bool llama_eval_internal(
|
||||||
inpL = cur;
|
inpL = cur;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// used at the end to optionally extract the embeddings
|
||||||
|
struct ggml_tensor * embeddings = NULL;
|
||||||
|
|
||||||
// norm
|
// norm
|
||||||
{
|
{
|
||||||
inpL = ggml_rms_norm(ctx0, inpL);
|
inpL = ggml_rms_norm(ctx0, inpL);
|
||||||
|
@ -818,18 +802,8 @@ static bool llama_eval_internal(
|
||||||
inpL = ggml_mul(ctx0,
|
inpL = ggml_mul(ctx0,
|
||||||
ggml_repeat(ctx0, model.norm, inpL),
|
ggml_repeat(ctx0, model.norm, inpL),
|
||||||
inpL);
|
inpL);
|
||||||
}
|
|
||||||
|
|
||||||
if(embedding_mode){
|
embeddings = inpL;
|
||||||
// capture input sentence embedding
|
|
||||||
ggml_build_forward_expand(&gf, inpL);
|
|
||||||
ggml_graph_compute (ctx0, &gf);
|
|
||||||
std::vector<float> embedding_representation;
|
|
||||||
embedding_representation.resize(n_embd);
|
|
||||||
memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 1)), sizeof(float) * n_embd);
|
|
||||||
display_embedding(embedding_representation);
|
|
||||||
ggml_free(ctx0);
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// lm_head
|
// lm_head
|
||||||
|
@ -852,15 +826,26 @@ static bool llama_eval_internal(
|
||||||
//embd_w.resize(n_vocab*N);
|
//embd_w.resize(n_vocab*N);
|
||||||
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
||||||
|
|
||||||
auto & logits_out = lctx.logits;
|
// extract logits
|
||||||
|
{
|
||||||
|
auto & logits_out = lctx.logits;
|
||||||
|
|
||||||
if (lctx.logits_all) {
|
if (lctx.logits_all) {
|
||||||
logits_out.resize(n_vocab * N);
|
logits_out.resize(n_vocab * N);
|
||||||
memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
||||||
} else {
|
} else {
|
||||||
// return result for just the last token
|
// return result for just the last token
|
||||||
logits_out.resize(n_vocab);
|
logits_out.resize(n_vocab);
|
||||||
memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract embeddings
|
||||||
|
if (lctx.embedding.size()) {
|
||||||
|
auto & embedding_out = lctx.embedding;
|
||||||
|
|
||||||
|
embedding_out.resize(n_embd);
|
||||||
|
memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mem_per_token == 0) {
|
if (mem_per_token == 0) {
|
||||||
|
@ -1441,12 +1426,26 @@ struct llama_context * llama_init_from_file(
|
||||||
|
|
||||||
ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||||
|
|
||||||
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, params.embedding)) {
|
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only)) {
|
||||||
fprintf(stderr, "%s: failed to load model\n", __func__);
|
fprintf(stderr, "%s: failed to load model\n", __func__);
|
||||||
delete ctx;
|
delete ctx;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reserve memory for context buffers
|
||||||
|
{
|
||||||
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
if (params.logits_all) {
|
||||||
|
ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
|
||||||
|
} else {
|
||||||
|
ctx->logits.reserve(hparams.n_ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.embedding){
|
||||||
|
ctx->embedding.reserve(hparams.n_embd);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1474,9 +1473,8 @@ int llama_eval(
|
||||||
const llama_token * tokens,
|
const llama_token * tokens,
|
||||||
int n_tokens,
|
int n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_threads,
|
int n_threads) {
|
||||||
bool embedding_mode = false) {
|
if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) {
|
||||||
if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, embedding_mode)) {
|
|
||||||
fprintf(stderr, "%s: failed to eval\n", __func__);
|
fprintf(stderr, "%s: failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
7
llama.h
7
llama.h
|
@ -53,7 +53,7 @@ extern "C" {
|
||||||
bool f16_kv; // use fp16 for KV cache
|
bool f16_kv; // use fp16 for KV cache
|
||||||
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
||||||
bool vocab_only; // only load the vocabulary, no weights
|
bool vocab_only; // only load the vocabulary, no weights
|
||||||
bool embedding; // embedding mode only
|
bool embedding; // embedding mode only
|
||||||
};
|
};
|
||||||
|
|
||||||
LLAMA_API struct llama_context_params llama_context_default_params();
|
LLAMA_API struct llama_context_params llama_context_default_params();
|
||||||
|
@ -85,8 +85,7 @@ extern "C" {
|
||||||
const llama_token * tokens,
|
const llama_token * tokens,
|
||||||
int n_tokens,
|
int n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_threads,
|
int n_threads);
|
||||||
bool embedding_mode);
|
|
||||||
|
|
||||||
// Convert the provided text into tokens.
|
// Convert the provided text into tokens.
|
||||||
// The tokens pointer must be large enough to hold the resulting tokens.
|
// The tokens pointer must be large enough to hold the resulting tokens.
|
||||||
|
@ -112,7 +111,7 @@ extern "C" {
|
||||||
|
|
||||||
// Get the embeddings for the input
|
// Get the embeddings for the input
|
||||||
// shape: [n_embd] (1-dimensional)
|
// shape: [n_embd] (1-dimensional)
|
||||||
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx)
|
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||||
|
|
||||||
// Token Id -> String. Uses the vocabulary in the provided context
|
// Token Id -> String. Uses the vocabulary in the provided context
|
||||||
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
|
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
|
||||||
|
|
16
main.cpp
16
main.cpp
|
@ -98,7 +98,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
|
||||||
int end = start + params.n_ctx - 1;
|
int end = start + params.n_ctx - 1;
|
||||||
std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
|
std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
|
||||||
auto start_t = std::chrono::high_resolution_clock::now();
|
auto start_t = std::chrono::high_resolution_clock::now();
|
||||||
if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads, false)) {
|
if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -220,7 +220,7 @@ int main(int argc, char ** argv) {
|
||||||
// TODO: better way to do that
|
// TODO: better way to do that
|
||||||
{
|
{
|
||||||
const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
|
const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
|
||||||
llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads, false);
|
llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.perplexity) {
|
if (params.perplexity) {
|
||||||
|
@ -302,7 +302,7 @@ int main(int argc, char ** argv) {
|
||||||
#endif
|
#endif
|
||||||
" - Press Return to return control to LLaMa.\n"
|
" - Press Return to return control to LLaMa.\n"
|
||||||
" - If you want to submit another line, end your input in '\\'.\n\n");
|
" - If you want to submit another line, end your input in '\\'.\n\n");
|
||||||
is_interacting = params.interactive_start;
|
is_interacting = params.interactive_start || params.instruct;
|
||||||
}
|
}
|
||||||
|
|
||||||
int input_consumed = 0;
|
int input_consumed = 0;
|
||||||
|
@ -325,23 +325,29 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
if (params.embedding){
|
if (params.embedding){
|
||||||
embd = embd_inp;
|
embd = embd_inp;
|
||||||
|
|
||||||
if (embd.size() > 0) {
|
if (embd.size() > 0) {
|
||||||
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads, params.embedding)) {
|
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const auto embeddings = llama_get_embeddings(ctx);
|
||||||
|
|
||||||
|
// TODO: print / use the embeddings
|
||||||
|
|
||||||
if (params.use_color) {
|
if (params.use_color) {
|
||||||
printf(ANSI_COLOR_RESET);
|
printf(ANSI_COLOR_RESET);
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
while (remaining_tokens > 0 || params.interactive) {
|
while (remaining_tokens > 0 || params.interactive) {
|
||||||
// predict
|
// predict
|
||||||
if (embd.size() > 0) {
|
if (embd.size() > 0) {
|
||||||
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads, params.embedding)) {
|
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue