fix: set ignore_merges only for llama-3

This commit is contained in:
Haoxiang Fei 2024-05-10 16:12:20 +08:00
parent c7614930f3
commit 8a51d3b12d
2 changed files with 19 additions and 4 deletions

View file

@ -12200,12 +12200,14 @@ struct llm_tokenizer_bpe {
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
int final_prev_index = -1;
int ignore_merges = false;
std::vector<std::string> word_collection;
switch (vocab.type) {
case LLAMA_VOCAB_TYPE_BPE:
switch (vocab.type_pre) {
case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
ignore_merges = true;
case LLAMA_VOCAB_PRE_TYPE_DBRX:
word_collection = unicode_regex_split(text, {
// original regex from tokenizer.json
@ -12292,7 +12294,7 @@ struct llm_tokenizer_bpe {
symbols_final.clear();
for (auto & word : word_collection) {
if (vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
if (ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
llm_symbol sym;
sym.text = word.c_str();
sym.n = word.size();

View file

@ -13,15 +13,27 @@
#include <vector>
int main(int argc, char **argv) {
if (argc < 2) {
fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
if (argc < 2 || argc > 3) {
fprintf(stderr, "Usage: %s <vocab-file> [--ignore-merges]\n", argv[0]);
return 1;
}
const std::string fname = argv[1];
bool ignore_merges = false;
if (argc == 3) {
if (std::strcmp(argv[2], "ignore-merges") != 0) {
fprintf(stderr, "Usage: %s <vocab-file> [--ignore-merges]\n", argv[0]);
return 1;
}
ignore_merges = true;
}
fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
if (ignore_merges) {
fprintf(stderr, "%s : ignoring merges for tokens inside vocab\n", __func__);
}
llama_model * model;
llama_context * ctx;
@ -66,7 +78,7 @@ int main(int argc, char **argv) {
try {
auto cps = unicode_cpts_from_utf8(str);
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false, true);
if (tokens.size() > 1) {
if (ignore_merges && tokens.size() > 1) {
fprintf(stderr,
"%s : error: token %d detokenizes to '%s'(%zu) but "
"tokenization of this to multiple tokens: [",
@ -76,6 +88,7 @@ int main(int argc, char **argv) {
fprintf(stderr, ", %d", tokens[i]);
}
fprintf(stderr, "]\n");
return 2;
}
std::string check = llama_detokenize_bpe(ctx, tokens);
if (check != str) {