Compute remaining tokens along the way and exit if over

This commit is contained in:
Xiao-Yong Jin 2023-03-17 00:20:24 -05:00
parent 235a4115df
commit 5be098f51e

View file

@ -919,7 +919,8 @@ int main(int argc, char ** argv) {
" - If you want to submit another line, end your input in '\\'.\n");
}
int remaining_tokens = params.n_predict;
// we may want to slide the input window along with the context, but for now we restrict to the context length
int remaining_tokens = model.hparams.n_ctx - embd_inp.size();
int input_consumed = 0;
bool input_noecho = true;
@ -935,7 +936,7 @@ int main(int argc, char ** argv) {
while (true) {
while (remaining_tokens > 0) {
// predict
if (embd.size() > 0) {
const int64_t t_start_us = ggml_time_us();
@ -980,7 +981,7 @@ int main(int argc, char ** argv) {
input_noecho = false;
// decrement remaining sampling budget
// --remaining_tokens;
--remaining_tokens;
} else {
// some user input remains from prompt or interaction, forward it to processing
while (embd_inp.size() > input_consumed) {
@ -1054,6 +1055,8 @@ int main(int argc, char ** argv) {
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
embd_inp.insert(embd_inp.end(), response_inp.begin(), response_inp.end());
remaining_tokens -= prompt_inp.size() + line_inp.size() + response_inp.size();
input_noecho = true; // do not echo this again
}