with logits_all == true, seek to the last logits vector

Maël Kerbiriou 2023-03-25 14:58:57 +01:00
parent 502a400192
commit ea546b5f8d
2 changed files with 7 additions and 8 deletions
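For readers landing on this commit: when params.logits_all is enabled (as in perplexity mode), llama_eval keeps one vocabulary-sized row of logits per evaluated token instead of a single row, so both the sampler and the EOS masking below have to seek to the last row. A minimal, self-contained sketch of the layout this commit relies on, using a plain std::vector and illustrative sizes rather than the llama.cpp API:

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_emb   = 5;       // tokens passed to the last eval call (illustrative)
        const int n_vocab = 32000;   // LLaMA vocabulary size

        // logits_all == false: the buffer holds 1 * n_vocab floats (last token only).
        // logits_all == true : it holds n_emb * n_vocab floats, row-major, one row per token.
        std::vector<float> logits(n_emb * n_vocab, 0.0f);

        // The last token's row starts at (n_emb - 1) * n_vocab = 4 * 32000 = 128000.
        const float * last = logits.data() + (n_emb - 1) * n_vocab;
        printf("last row starts at element %ld of %ld\n",
               (long)(last - logits.data()), (long)logits.size());
        return 0;
    }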

@@ -1262,9 +1262,8 @@ static llama_vocab::id llama_sample_top_p_top_k(
     auto & rng = lctx.rng;
 
     const auto & vocab = lctx.vocab;
-    const auto & logits = lctx.logits;
-
-    int n_logits = vocab.id_to_token.size();
+    const int n_logits = vocab.id_to_token.size();
+    const auto logits = lctx.logits.end() - n_logits;
 
     std::vector<std::pair<double, llama_vocab::id>> logits_id;
     logits_id.reserve(n_logits);
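A side note on the hunk above: lctx.logits is a std::vector<float>, so end() - n_logits is an iterator to the first element of the last vocabulary-sized block, which is the right place to sample from whether the vector holds a single row (logits_all == false) or one row per evaluated token (logits_all == true). A small sketch of that iterator arithmetic with made-up sizes:

    #include <cassert>
    #include <vector>

    int main() {
        const int n_logits = 8;  // stands in for vocab.id_to_token.size()

        // One row kept (logits_all == false) ...
        std::vector<float> single(n_logits, 0.0f);
        // ... or one row per evaluated token (logits_all == true), here 3 tokens.
        std::vector<float> all(3 * n_logits, 0.0f);
        all[2 * n_logits] = 42.0f;  // first logit of the last row

        // end() - n_logits always lands on the start of the last row.
        assert(*(single.end() - n_logits) == 0.0f);
        assert(*(all.end() - n_logits) == 42.0f);
        return 0;
    }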

@@ -250,6 +250,7 @@ int main(int argc, char ** argv) {
     auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
 
     const int n_ctx = llama_n_ctx(ctx);
+    const int n_vocab = llama_n_vocab(ctx);
 
     params.n_predict = std::min(params.n_predict, n_ctx - (int) embd_inp.size());
@@ -368,9 +369,10 @@ int main(int argc, char ** argv) {
     }
 
     while (remaining_tokens > 0 || params.interactive) {
+        const int n_emb = embd.size();
         // predict
         if (embd.size() > 0) {
-            if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
+            if (llama_eval(ctx, embd.data(), n_emb, n_past, params.n_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return 1;
             }
@@ -389,12 +391,10 @@ int main(int argc, char ** argv) {
             llama_token id = 0;
 
             {
-                auto logits = llama_get_logits(ctx);
+                // Logits after the last token
+                auto logits = llama_get_logits(ctx) + (n_emb - 1) * n_vocab;
 
                 if (params.ignore_eos) {
                     // set the logit of the eos token to zero to avoid sampling it
-                    //logits[logits.size() - n_vocab + EOS_TOKEN_ID] = 0;
-                    // TODO: this does not work of params.logits_all == true
-                    assert(params.perplexity == false);
                     logits[llama_token_eos()] = 0;
                 }
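The main.cpp side of the same idea: llama_get_logits returns a float * to the start of the logits buffer, so with logits_all the caller adds (n_emb - 1) * n_vocab to reach the last token's row before masking EOS. A self-contained sketch that mirrors the new code path with a stand-in buffer instead of a real context (eos_id and the sizes are illustrative, not values from the library):

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_emb   = 3;   // tokens in the last eval batch (illustrative)
        const int n_vocab = 8;   // vocabulary size (illustrative)
        const int eos_id  = 2;   // stands in for llama_token_eos()

        // Stand-in for the buffer behind llama_get_logits(ctx) with logits_all == true.
        std::vector<float> buf(n_emb * n_vocab, 1.0f);

        // Logits after the last token, as in the diff.
        float * logits = buf.data() + (n_emb - 1) * n_vocab;

        // Set the logit of the eos token to zero to avoid sampling it, as in the diff.
        logits[eos_id] = 0.0f;

        printf("eos logit in the last row: %f\n", logits[eos_id]);
        return 0;
    }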