with logits_all == true, seek to the last logits vector
This commit is contained in:
parent
502a400192
commit
ea546b5f8d
2 changed files with 7 additions and 8 deletions
|
@ -1262,9 +1262,8 @@ static llama_vocab::id llama_sample_top_p_top_k(
|
|||
auto & rng = lctx.rng;
|
||||
|
||||
const auto & vocab = lctx.vocab;
|
||||
const auto & logits = lctx.logits;
|
||||
|
||||
int n_logits = vocab.id_to_token.size();
|
||||
const int n_logits = vocab.id_to_token.size();
|
||||
const auto logits = lctx.logits.end() - n_logits;
|
||||
|
||||
std::vector<std::pair<double, llama_vocab::id>> logits_id;
|
||||
logits_id.reserve(n_logits);
|
||||
|
|
10
main.cpp
10
main.cpp
|
@ -250,6 +250,7 @@ int main(int argc, char ** argv) {
|
|||
auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
|
||||
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
const int n_vocab = llama_n_vocab(ctx);
|
||||
|
||||
params.n_predict = std::min(params.n_predict, n_ctx - (int) embd_inp.size());
|
||||
|
||||
|
@ -368,9 +369,10 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
while (remaining_tokens > 0 || params.interactive) {
|
||||
const int n_emb = embd.size();
|
||||
// predict
|
||||
if (embd.size() > 0) {
|
||||
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
|
||||
if (llama_eval(ctx, embd.data(), n_emb, n_past, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
@ -389,12 +391,10 @@ int main(int argc, char ** argv) {
|
|||
llama_token id = 0;
|
||||
|
||||
{
|
||||
auto logits = llama_get_logits(ctx);
|
||||
|
||||
if (params.ignore_eos) {
|
||||
// Logits after the last token
|
||||
auto logits = llama_get_logits(ctx) + (n_emb - 1) * n_vocab;
|
||||
// set the logit of the eos token to zero to avoid sampling it
|
||||
//logits[logits.size() - n_vocab + EOS_TOKEN_ID] = 0;
|
||||
// TODO: this does not work if params.logits_all == true
|
||||
assert(params.perplexity == false);
|
||||
logits[llama_token_eos()] = 0;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue