Truncate prompt if longer than context + n_predict
This commit is contained in:
parent
b9bd1d0141
commit
1133eea479
1 changed files with 4 additions and 2 deletions
6
main.cpp
6
main.cpp
|
@ -782,8 +782,10 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// tokenize the prompt
|
// tokenize the prompt
|
||||||
std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(vocab, params.prompt, true);
|
std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(vocab, params.prompt, true);
|
||||||
|
if (embd_inp.size() + params.n_predict > model.hparams.n_ctx) {
|
||||||
params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
|
int offset = embd_inp.size() - model.hparams.n_ctx + params.n_predict;
|
||||||
|
embd_inp = std::vector<gpt_vocab::id>(embd_inp.begin() + offset, embd_inp.end());
|
||||||
|
}
|
||||||
|
|
||||||
printf("\n");
|
printf("\n");
|
||||||
printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
|
printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue