parent 7e2b5fb1dd
commit 56ba00b923
9 changed files with 119 additions and 105 deletions
@@ -611,7 +611,7 @@ int main(int argc, char ** argv) {
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);

-            llama_sampling_accept(ctx_sampling, ctx, id);
+            llama_sampling_accept(ctx_sampling, ctx, id, true);

             LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
@@ -630,12 +630,9 @@ int main(int argc, char ** argv) {
         while ((int) embd_inp.size() > n_consumed) {
             embd.push_back(embd_inp[n_consumed]);

-            // GG: I'm not sure it's a good idea to push the prompt tokens into the sampling context
-            // Most likely will remove this in the future to avoid exposing "prev"
-            // Same thing is done in "server". If we stop pushing the prompt tokens, then the repetition
-            // penalty will be applied only based on the tokens generated by the model.
-            ctx_sampling->prev.erase(ctx_sampling->prev.begin());
-            ctx_sampling->prev.push_back(embd_inp[n_consumed]);
+            // push the prompt in the sampling context in order to apply repetition penalties later
+            // for the prompt, we don't apply grammar rules
+            llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);

             ++n_consumed;
             if ((int) embd.size() >= params.n_batch) {
@@ -666,12 +663,10 @@ int main(int argc, char ** argv) {

         // if not currently processing queued inputs;
         if ((int) embd_inp.size() <= n_consumed) {
-            // check for reverse prompt
+            // check for reverse prompt in the last n_prev tokens
             if (!params.antiprompt.empty()) {
-                std::string last_output;
-                for (auto id : ctx_sampling->prev) {
-                    last_output += llama_token_to_piece(ctx, id);
-                }
+                const int n_prev = 32;
+                const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);

                 is_antiprompt = false;
                 // Check if each of the reverse prompts appears at the end of the output.
@@ -698,7 +693,7 @@ int main(int argc, char ** argv) {
             }

             // deal with end of text token in interactive mode
-            if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
+            if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
                 LOG("found EOS token\n");

                 if (params.interactive) {