From ea546b5f8d8b824747baddfdaa68c5b94cbc39c1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ma=C3=ABl=20Kerbiriou?=
Date: Sat, 25 Mar 2023 14:58:57 +0100
Subject: [PATCH] with logits_all == true, seek to the last logits vector

---
 llama.cpp |  5 ++---
 main.cpp  | 10 +++++-----
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 0015edec1..f10a8141e 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1262,9 +1262,8 @@ static llama_vocab::id llama_sample_top_p_top_k(
     auto & rng = lctx.rng;
 
     const auto & vocab = lctx.vocab;
-    const auto & logits = lctx.logits;
-
-    int n_logits = vocab.id_to_token.size();
+    const int n_logits = vocab.id_to_token.size();
+    const auto logits = lctx.logits.end() - n_logits;
 
     std::vector<std::pair<double, llama_vocab::id>> logits_id;
     logits_id.reserve(n_logits);
diff --git a/main.cpp b/main.cpp
index 77260bb71..c0b0c55db 100644
--- a/main.cpp
+++ b/main.cpp
@@ -250,6 +250,7 @@ int main(int argc, char ** argv) {
     auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
 
     const int n_ctx = llama_n_ctx(ctx);
+    const int n_vocab = llama_n_vocab(ctx);
 
     params.n_predict = std::min(params.n_predict, n_ctx - (int) embd_inp.size());
 
@@ -368,9 +369,10 @@ int main(int argc, char ** argv) {
     }
 
     while (remaining_tokens > 0 || params.interactive) {
+        const int n_emb = embd.size();
        // predict
        if (embd.size() > 0) {
-            if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
+            if (llama_eval(ctx, embd.data(), n_emb, n_past, params.n_threads)) {
                fprintf(stderr, "%s : failed to eval\n", __func__);
                return 1;
            }
@@ -389,12 +391,10 @@ int main(int argc, char ** argv) {
            llama_token id = 0;

            {
-                auto logits = llama_get_logits(ctx);
-
                if (params.ignore_eos) {
+                    // Logits after the last token
+                    auto logits = llama_get_logits(ctx) + (n_emb - 1) * n_vocab;
                    // set the logit of the eos token to zero to avoid sampling it
-                    //logits[logits.size() - n_vocab + EOS_TOKEN_ID] = 0;
-                    // TODO: this does not work of params.logits_all == true
                    assert(params.perplexity == false);
                    logits[llama_token_eos()] = 0;
                }
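
Note (not part of the patch): a minimal sketch of the indexing this change relies on, assuming the llama.cpp C API at this revision (llama_eval, llama_get_logits, llama_n_vocab) and a context created with logits_all == true; the helper name last_token_logits is hypothetical.

    #include "llama.h"

    // Sketch: with logits_all == true, llama_get_logits() exposes one row of
    // n_vocab logits per token passed to the last llama_eval() call, laid out
    // row-major. The sampler wants the row of the last token, i.e. the offset
    // (n_tokens - 1) * n_vocab into that buffer.
    static float * last_token_logits(struct llama_context * ctx, int n_tokens) {
        const int n_vocab = llama_n_vocab(ctx);
        return llama_get_logits(ctx) + (n_tokens - 1) * n_vocab;
    }

This is the same arithmetic the patch applies in main.cpp via (n_emb - 1) * n_vocab; the llama.cpp hunk reaches the same row from the other end, lctx.logits.end() - n_logits, which also stays valid when logits_all == false and the buffer holds only a single row of n_vocab logits.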