Merge pull request #18 from jxy/limit_tokens

Compute remaining tokens along the way and exit if over
This commit is contained in:
Kevin Kwok 2023-03-17 09:44:12 -07:00 committed by GitHub
commit 197df5f096
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -919,7 +919,8 @@ int main(int argc, char ** argv) {
" - If you want to submit another line, end your input in '\\'.\n");
}
int remaining_tokens = params.n_predict;
// we may want to slide the input window along with the context, but for now we restrict to the context length
int remaining_tokens = model.hparams.n_ctx - embd_inp.size();
int input_consumed = 0;
bool input_noecho = true;
@ -935,7 +936,7 @@ int main(int argc, char ** argv) {
while (true) {
while (remaining_tokens > 0) {
// predict
if (embd.size() > 0) {
const int64_t t_start_us = ggml_time_us();
@ -980,7 +981,7 @@ int main(int argc, char ** argv) {
input_noecho = false;
// decrement remaining sampling budget
// --remaining_tokens;
--remaining_tokens;
} else {
// some user input remains from prompt or interaction, forward it to processing
while (embd_inp.size() > input_consumed) {
@ -1054,6 +1055,8 @@ int main(int argc, char ** argv) {
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
embd_inp.insert(embd_inp.end(), response_inp.begin(), response_inp.end());
remaining_tokens -= prompt_inp.size() + line_inp.size() + response_inp.size();
input_noecho = true; // do not echo this again
}