Embeddings extraction support
parent 859e70899a
commit 8a3c34bb54
3 changed files with 59 additions and 56 deletions
llama.cpp (92 changes)
@@ -101,9 +101,10 @@ struct llama_context {
    // decode output (2-dimensional array: [n_tokens][n_vocab])
    std::vector<float> logits;
    bool logits_all = false;

    // input embedding (1-dimensional array: [n_embd])
    std::vector<float> embedding;
    bool logits_all = false;
};

struct llama_context_params llama_context_default_params() {
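For orientation, these two buffers are what the public accessors hand back to callers. A sketch of the presumed relationship (the bodies below are assumptions for illustration; they are not part of this hunk):

    // presumed accessor bodies (not shown in this diff)
    float * llama_get_logits    (struct llama_context * ctx) { return ctx->logits.data();    }
    float * llama_get_embeddings(struct llama_context * ctx) { return ctx->embedding.data(); }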
@@ -114,7 +115,7 @@ struct llama_context_params llama_context_default_params() {
        /*.f16_kv     =*/ false,
        /*.logits_all =*/ false,
        /*.vocab_only =*/ false,
        /*.embedding  =*/ false,
        /*.embedding  =*/ false,
    };

    return result;
@@ -130,8 +131,7 @@ static bool llama_model_load(
        int n_ctx,
        int n_parts,
        ggml_type memory_type,
        bool vocab_only,
        bool embedding) {
        bool vocab_only) {
    fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());

    const int64_t t_start_us = ggml_time_us();
@@ -596,29 +596,11 @@ static bool llama_model_load(
        fin.close();
    }

    lctx.logits.reserve(lctx.model.hparams.n_ctx);

    if (embedding){
        lctx.embedding.reserve(lctx.model.hparams.n_embd);
    }

    lctx.t_load_us = ggml_time_us() - t_start_us;

    return true;
}

// Prints the provided embedding vector to stdout
// in a neat format
void display_embedding(const std::vector<float> & embedding_representation){
    fprintf(stdout, "\n[\n");
    for (int j = 0; j < embedding_representation.size()-1 ; j++){
        fprintf(stdout, "%f, ", embedding_representation[j]);
    }
    fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]);
    fprintf(stdout, "\n]\n");
}


// evaluate the transformer
//
// - lctx: llama context
@@ -631,8 +613,7 @@ static bool llama_eval_internal(
        const llama_token * tokens,
        const int n_tokens,
        const int n_past,
        const int n_threads,
        const bool embedding_mode = false) {
        const int n_threads) {
    const int64_t t_start_us = ggml_time_us();

    const int N = n_tokens;
@@ -810,6 +791,9 @@ static bool llama_eval_internal(
        inpL = cur;
    }

    // used at the end to optionally extract the embeddings
    struct ggml_tensor * embeddings = NULL;

    // norm
    {
        inpL = ggml_rms_norm(ctx0, inpL);
@@ -818,18 +802,8 @@ static bool llama_eval_internal(
        inpL = ggml_mul(ctx0,
                    ggml_repeat(ctx0, model.norm, inpL),
                    inpL);
    }

    if(embedding_mode){
        // capture input sentence embedding
        ggml_build_forward_expand(&gf, inpL);
        ggml_graph_compute (ctx0, &gf);
        std::vector<float> embedding_representation;
        embedding_representation.resize(n_embd);
        memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 1)), sizeof(float) * n_embd);
        display_embedding(embedding_representation);
        ggml_free(ctx0);
        return true;
        embeddings = inpL;
    }

    // lm_head
@@ -852,15 +826,26 @@ static bool llama_eval_internal(
    //embd_w.resize(n_vocab*N);
    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);

    auto & logits_out = lctx.logits;
    // extract logits
    {
        auto & logits_out = lctx.logits;

    if (lctx.logits_all) {
        logits_out.resize(n_vocab * N);
        memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
    } else {
        // return result for just the last token
        logits_out.resize(n_vocab);
        memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
        if (lctx.logits_all) {
            logits_out.resize(n_vocab * N);
            memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
        } else {
            // return result for just the last token
            logits_out.resize(n_vocab);
            memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
        }
    }

    // extract embeddings
    if (lctx.embedding.size()) {
        auto & embedding_out = lctx.embedding;

        embedding_out.resize(n_embd);
        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
    }

    if (mem_per_token == 0) {
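For clarity, a note on the layout the embedding memcpy above relies on (an inference from the code, not a statement in the commit): ggml_get_data(embeddings) yields N rows of n_embd floats, one row per evaluated token, and only the last row is kept in lctx.embedding.

    // illustrative sketch of the indexing used above
    const float * data = (const float *) ggml_get_data(embeddings); // N * n_embd floats, row-major
    const float * last = data + n_embd*(N - 1);                     // row for the final token
    // memcpy(embedding_out.data(), last, sizeof(float)*n_embd);    // what the hunk above does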
@@ -1441,12 +1426,26 @@ struct llama_context * llama_init_from_file(
    ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;

    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, params.embedding)) {
    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only)) {
        fprintf(stderr, "%s: failed to load model\n", __func__);
        delete ctx;
        return nullptr;
    }

    // reserve memory for context buffers
    {
        const auto & hparams = ctx->model.hparams;
        if (params.logits_all) {
            ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
        } else {
            ctx->logits.reserve(hparams.n_ctx);
        }

        if (params.embedding){
            ctx->embedding.reserve(hparams.n_embd);
        }
    }

    return ctx;
}
@@ -1474,9 +1473,8 @@ int llama_eval(
        const llama_token * tokens,
        int n_tokens,
        int n_past,
        int n_threads,
        bool embedding_mode = false) {
    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, embedding_mode)) {
        int n_threads) {
    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) {
        fprintf(stderr, "%s: failed to eval\n", __func__);
        return 1;
    }
llama.h (7 changes)
@@ -53,7 +53,7 @@ extern "C" {
        bool f16_kv;     // use fp16 for KV cache
        bool logits_all; // the llama_eval() call computes all logits, not just the last one
        bool vocab_only; // only load the vocabulary, no weights
        bool embedding;  // embedding mode only
        bool embedding;  // embedding mode only
    };

    LLAMA_API struct llama_context_params llama_context_default_params();
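A minimal sketch of how the flag is meant to be used when creating a context (the model path below is illustrative, not from this commit):

    #include "llama.h"

    llama_context_params params = llama_context_default_params();
    params.embedding = true; // have llama_init_from_file reserve ctx->embedding

    llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", params);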
@@ -85,8 +85,7 @@ extern "C" {
            const llama_token * tokens,
            int n_tokens,
            int n_past,
            int n_threads,
            bool embedding_mode);
            int n_threads);

    // Convert the provided text into tokens.
    // The tokens pointer must be large enough to hold the resulting tokens.
@@ -112,7 +111,7 @@ extern "C" {

    // Get the embeddings for the input
    // shape: [n_embd] (1-dimensional)
    LLAMA_API float * llama_get_embeddings(struct llama_context * ctx)
    LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

    // Token Id -> String. Uses the vocabulary in the provided context
    LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
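For reference, a sketch of the intended call pattern (assuming llama_n_embd() is available in this header like the other llama_n_* accessors, and that ctx, tokens, n_tokens, and n_threads are already set up):

    // evaluate the prompt, then read the [n_embd] embedding captured for the last token
    if (llama_eval(ctx, tokens, n_tokens, 0, n_threads) == 0) {
        const int     n_embd = llama_n_embd(ctx);
        const float * embd   = llama_get_embeddings(ctx);
        for (int i = 0; i < n_embd; i++) {
            printf("%f ", embd[i]);
        }
        printf("\n");
    }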
main.cpp (16 changes)
@@ -98,7 +98,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
        int end = start + params.n_ctx - 1;
        std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
        auto start_t = std::chrono::high_resolution_clock::now();
        if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads, false)) {
        if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return;
        }
@@ -220,7 +220,7 @@ int main(int argc, char ** argv) {
    // TODO: better way to do that
    {
        const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
        llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads, false);
        llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
    }

    if (params.perplexity) {
@@ -302,7 +302,7 @@ int main(int argc, char ** argv) {
#endif
               " - Press Return to return control to LLaMa.\n"
               " - If you want to submit another line, end your input in '\\'.\n\n");
        is_interacting = params.interactive_start;
        is_interacting = params.interactive_start || params.instruct;
    }

    int input_consumed = 0;
@@ -325,23 +325,29 @@ int main(int argc, char ** argv) {

    if (params.embedding){
        embd = embd_inp;

        if (embd.size() > 0) {
            if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads, params.embedding)) {
            if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
                fprintf(stderr, "%s : failed to eval\n", __func__);
                return 1;
            }
        }

        const auto embeddings = llama_get_embeddings(ctx);

        // TODO: print / use the embeddings

        if (params.use_color) {
            printf(ANSI_COLOR_RESET);
        }

        return 0;
    }

    while (remaining_tokens > 0 || params.interactive) {
        // predict
        if (embd.size() > 0) {
            if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads, params.embedding)) {
            if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
                fprintf(stderr, "%s : failed to eval\n", __func__);
                return 1;
            }
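One possible way to act on the "TODO: print / use the embeddings" above, sketched only as an illustration (the cosine_similarity helper below is hypothetical and not part of this commit): compare the vectors of two prompts obtained via llama_get_embeddings().

    #include <math.h>

    // hypothetical helper: cosine similarity between two n-dimensional embeddings
    static float cosine_similarity(const float * a, const float * b, int n) {
        float dot = 0.0f, na = 0.0f, nb = 0.0f;
        for (int i = 0; i < n; i++) {
            dot += a[i]*b[i];
            na  += a[i]*a[i];
            nb  += b[i]*b[i];
        }
        return dot / (sqrtf(na)*sqrtf(nb) + 1e-6f);
    }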