Embeddings extraction support

This commit is contained in:
Georgi Gerganov 2023-03-23 22:02:14 +02:00
parent 859e70899a
commit 8a3c34bb54
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 59 additions and 56 deletions

View file

@ -101,9 +101,10 @@ struct llama_context {
// decode output (2-dimensional array: [n_tokens][n_vocab]) // decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits; std::vector<float> logits;
bool logits_all = false;
// input embedding (1-dimensional array: [n_embd]) // input embedding (1-dimensional array: [n_embd])
std::vector<float> embedding; std::vector<float> embedding;
bool logits_all = false;
}; };
struct llama_context_params llama_context_default_params() { struct llama_context_params llama_context_default_params() {
@ -114,7 +115,7 @@ struct llama_context_params llama_context_default_params() {
/*.f16_kv =*/ false, /*.f16_kv =*/ false,
/*.logits_all =*/ false, /*.logits_all =*/ false,
/*.vocab_only =*/ false, /*.vocab_only =*/ false,
/*.embedding =*/ false, /*.embedding =*/ false,
}; };
return result; return result;
@ -130,8 +131,7 @@ static bool llama_model_load(
int n_ctx, int n_ctx,
int n_parts, int n_parts,
ggml_type memory_type, ggml_type memory_type,
bool vocab_only, bool vocab_only) {
bool embedding) {
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
const int64_t t_start_us = ggml_time_us(); const int64_t t_start_us = ggml_time_us();
@ -596,29 +596,11 @@ static bool llama_model_load(
fin.close(); fin.close();
} }
lctx.logits.reserve(lctx.model.hparams.n_ctx);
if (embedding){
lctx.embedding.reserve(lctx.model.hparams.n_embd);
}
lctx.t_load_us = ggml_time_us() - t_start_us; lctx.t_load_us = ggml_time_us() - t_start_us;
return true; return true;
} }
// Prints the provided embedding vector to stdout
// in a neat format
void display_embedding(const std::vector<float> & embedding_representation){
fprintf(stdout, "\n[\n");
for (int j = 0; j < embedding_representation.size()-1 ; j++){
fprintf(stdout, "%f, ", embedding_representation[j]);
}
fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]);
fprintf(stdout, "\n]\n");
}
// evaluate the transformer // evaluate the transformer
// //
// - lctx: llama context // - lctx: llama context
@ -631,8 +613,7 @@ static bool llama_eval_internal(
const llama_token * tokens, const llama_token * tokens,
const int n_tokens, const int n_tokens,
const int n_past, const int n_past,
const int n_threads, const int n_threads) {
const bool embedding_mode = false) {
const int64_t t_start_us = ggml_time_us(); const int64_t t_start_us = ggml_time_us();
const int N = n_tokens; const int N = n_tokens;
@ -810,6 +791,9 @@ static bool llama_eval_internal(
inpL = cur; inpL = cur;
} }
// used at the end to optionally extract the embeddings
struct ggml_tensor * embeddings = NULL;
// norm // norm
{ {
inpL = ggml_rms_norm(ctx0, inpL); inpL = ggml_rms_norm(ctx0, inpL);
@ -818,18 +802,8 @@ static bool llama_eval_internal(
inpL = ggml_mul(ctx0, inpL = ggml_mul(ctx0,
ggml_repeat(ctx0, model.norm, inpL), ggml_repeat(ctx0, model.norm, inpL),
inpL); inpL);
}
if(embedding_mode){ embeddings = inpL;
// capture input sentence embedding
ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute (ctx0, &gf);
std::vector<float> embedding_representation;
embedding_representation.resize(n_embd);
memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 1)), sizeof(float) * n_embd);
display_embedding(embedding_representation);
ggml_free(ctx0);
return true;
} }
// lm_head // lm_head
@ -852,15 +826,26 @@ static bool llama_eval_internal(
//embd_w.resize(n_vocab*N); //embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N); //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
auto & logits_out = lctx.logits; // extract logits
{
auto & logits_out = lctx.logits;
if (lctx.logits_all) { if (lctx.logits_all) {
logits_out.resize(n_vocab * N); logits_out.resize(n_vocab * N);
memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N); memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
} else { } else {
// return result for just the last token // return result for just the last token
logits_out.resize(n_vocab); logits_out.resize(n_vocab);
memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
}
}
// extract embeddings
if (lctx.embedding.size()) {
auto & embedding_out = lctx.embedding;
embedding_out.resize(n_embd);
memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
} }
if (mem_per_token == 0) { if (mem_per_token == 0) {
@ -1441,12 +1426,26 @@ struct llama_context * llama_init_from_file(
ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, params.embedding)) { if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only)) {
fprintf(stderr, "%s: failed to load model\n", __func__); fprintf(stderr, "%s: failed to load model\n", __func__);
delete ctx; delete ctx;
return nullptr; return nullptr;
} }
// reserve memory for context buffers
{
const auto & hparams = ctx->model.hparams;
if (params.logits_all) {
ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
} else {
ctx->logits.reserve(hparams.n_ctx);
}
if (params.embedding){
ctx->embedding.reserve(hparams.n_embd);
}
}
return ctx; return ctx;
} }
@ -1474,9 +1473,8 @@ int llama_eval(
const llama_token * tokens, const llama_token * tokens,
int n_tokens, int n_tokens,
int n_past, int n_past,
int n_threads, int n_threads) {
bool embedding_mode = false) { if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) {
if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, embedding_mode)) {
fprintf(stderr, "%s: failed to eval\n", __func__); fprintf(stderr, "%s: failed to eval\n", __func__);
return 1; return 1;
} }

View file

@ -53,7 +53,7 @@ extern "C" {
bool f16_kv; // use fp16 for KV cache bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights bool vocab_only; // only load the vocabulary, no weights
bool embedding; // embedding mode only bool embedding; // embedding mode only
}; };
LLAMA_API struct llama_context_params llama_context_default_params(); LLAMA_API struct llama_context_params llama_context_default_params();
@ -85,8 +85,7 @@ extern "C" {
const llama_token * tokens, const llama_token * tokens,
int n_tokens, int n_tokens,
int n_past, int n_past,
int n_threads, int n_threads);
bool embedding_mode);
// Convert the provided text into tokens. // Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens. // The tokens pointer must be large enough to hold the resulting tokens.
@ -112,7 +111,7 @@ extern "C" {
// Get the embeddings for the input // Get the embeddings for the input
// shape: [n_embd] (1-dimensional) // shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx) LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context // Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token); LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);

View file

@ -98,7 +98,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
int end = start + params.n_ctx - 1; int end = start + params.n_ctx - 1;
std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end); std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
auto start_t = std::chrono::high_resolution_clock::now(); auto start_t = std::chrono::high_resolution_clock::now();
if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads, false)) { if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return; return;
} }
@ -220,7 +220,7 @@ int main(int argc, char ** argv) {
// TODO: better way to do that // TODO: better way to do that
{ {
const std::vector<llama_token> tmp = { 0, 1, 2, 3 }; const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads, false); llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
} }
if (params.perplexity) { if (params.perplexity) {
@ -302,7 +302,7 @@ int main(int argc, char ** argv) {
#endif #endif
" - Press Return to return control to LLaMa.\n" " - Press Return to return control to LLaMa.\n"
" - If you want to submit another line, end your input in '\\'.\n\n"); " - If you want to submit another line, end your input in '\\'.\n\n");
is_interacting = params.interactive_start; is_interacting = params.interactive_start || params.instruct;
} }
int input_consumed = 0; int input_consumed = 0;
@ -325,23 +325,29 @@ int main(int argc, char ** argv) {
if (params.embedding){ if (params.embedding){
embd = embd_inp; embd = embd_inp;
if (embd.size() > 0) { if (embd.size() > 0) {
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads, params.embedding)) { if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return 1; return 1;
} }
} }
const auto embeddings = llama_get_embeddings(ctx);
// TODO: print / use the embeddings
if (params.use_color) { if (params.use_color) {
printf(ANSI_COLOR_RESET); printf(ANSI_COLOR_RESET);
} }
return 0; return 0;
} }
while (remaining_tokens > 0 || params.interactive) { while (remaining_tokens > 0 || params.interactive) {
// predict // predict
if (embd.size() > 0) { if (embd.size() > 0) {
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads, params.embedding)) { if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return 1; return 1;
} }