Add embedding mode with arg flag. Currently working (#282)

* working but ugly

* add arg flag, not working on embedding mode

* typo

* Working! Thanks to @nullhook

* make params argument instead of hardcoded boolean. remove useless time check

* start doing the instructions but not finished. This probably doesnt compile

* Embeddings extraction support

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Luciano 2023-03-24 08:05:13 -07:00 committed by GitHub
parent b6b268d441
commit 8d4a855c24
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 82 additions and 10 deletions

View file

@ -199,6 +199,7 @@ int main(int argc, char ** argv) {
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;
ctx = llama_init_from_file(params.model.c_str(), lparams);
@ -292,6 +293,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
int last_n_size = params.repeat_last_n;
std::vector<llama_token> last_n_tokens(last_n_size);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
@ -324,6 +326,27 @@ int main(int argc, char ** argv) {
// the first thing we will do is to output the prompt, so set color accordingly
set_console_state(CONSOLE_STATE_PROMPT);
if (params.embedding){
embd = embd_inp;
if (embd.size() > 0) {
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
}
const auto embeddings = llama_get_embeddings(ctx);
// TODO: print / use the embeddings
if (params.use_color) {
printf(ANSI_COLOR_RESET);
}
return 0;
}
while (remaining_tokens > 0 || params.interactive) {
// predict
if (embd.size() > 0) {