Add SPM infill support (#8016)

* add --spm-infill option

* support --spm-infill

* support --spm-infill
This commit is contained in:
Sigbjørn Skjæret 2024-06-28 12:53:43 +02:00 committed by GitHub
parent b851b3fba0
commit 38373cfbab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 34 additions and 16 deletions

View file

@ -15,6 +15,7 @@ In this section, we cover the most commonly used options for running the `infill
- `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses.
- `-n N, --n-predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text.
- `-c N, --ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference.
- `--spm-infill`: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
## Input Prompts

View file

@ -210,6 +210,7 @@ int main(int argc, char ** argv) {
suff_rm_leading_spc = false;
}
std::vector<llama_token> embd_inp;
std::vector<llama_token> embd_end;
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
const int space_token = 29871;
@ -217,12 +218,13 @@ int main(int argc, char ** argv) {
inp_sfx.erase(inp_sfx.begin());
}
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
if (add_bos) {
inp_pfx.insert(inp_pfx.begin(), llama_token_bos(model));
}
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
embd_inp = inp_pfx;
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
embd_end = params.spm_infill ? inp_pfx : inp_sfx;
if (add_bos) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
}
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
const llama_token middle_token = llama_token_middle(model);
if (middle_token >= 0) {
@ -526,14 +528,14 @@ int main(int argc, char ** argv) {
inp_sfx.erase(inp_sfx.begin());
}
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
if (add_bos) {
inp_pfx.insert(inp_pfx.begin(), llama_token_bos(model));
}
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
embd_inp = inp_pfx;
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
embd_end = params.spm_infill ? inp_pfx : inp_sfx;
if (add_bos) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
}
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
const llama_token middle_token = llama_token_middle(model);
if (middle_token >= 0) {
embd_inp.push_back(middle_token);
}