examples : do not evaluate the prompt twice (close #3348)
This commit is contained in:
parent
a207561503
commit
2b8830af71
2 changed files with 24 additions and 21 deletions
|
@ -110,16 +110,10 @@ int main(int argc, char ** argv) {
|
|||
const auto t_main_start = ggml_time_us();
|
||||
|
||||
while (n_cur <= n_len) {
|
||||
// evaluate the current batch with the transformer model
|
||||
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// sample the next token
|
||||
{
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
auto logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
@ -158,6 +152,12 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
n_cur += 1;
|
||||
|
||||
// evaluate the current batch with the transformer model
|
||||
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_TEE("\n");
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue