Refactor special tokens tokenization
This commit is contained in:
parent
ada6cce40f
commit
465cadd44c
1 changed files with 30 additions and 54 deletions
84
llama.cpp
84
llama.cpp
|
@ -279,14 +279,9 @@ struct llama_vocab {
|
||||||
std::vector<token_score> id_to_token;
|
std::vector<token_score> id_to_token;
|
||||||
|
|
||||||
std::unordered_map<token, id> special_token_to_id;
|
std::unordered_map<token, id> special_token_to_id;
|
||||||
size_t max_special_token_length = 0;
|
|
||||||
|
|
||||||
void add_special_token(const token & word, id token_id) {
|
void add_special_token(const token & word, id token_id) {
|
||||||
special_token_to_id[word] = token_id;
|
special_token_to_id[word] = token_id;
|
||||||
|
|
||||||
if (max_special_token_length < word.size()) {
|
|
||||||
max_special_token_length = word.size();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -2088,38 +2083,6 @@ private:
|
||||||
llama_sp_bigram::queue work_queue_;
|
llama_sp_bigram::queue work_queue_;
|
||||||
};
|
};
|
||||||
|
|
||||||
static std::vector<size_t> llama_split_special_tokens(const llama_vocab & vocab, const std::string & text) {
|
|
||||||
std::vector<size_t> offsets{0};
|
|
||||||
size_t start = 0;
|
|
||||||
|
|
||||||
while (start < text.size()) {
|
|
||||||
size_t max_end = start;
|
|
||||||
const std::string * max_delimiter = nullptr;
|
|
||||||
|
|
||||||
for (const auto & mit : vocab.special_token_to_id) {
|
|
||||||
const std::string & delimiter = mit.first;
|
|
||||||
size_t end = start + delimiter.size();
|
|
||||||
if (end <= text.size() && text.compare(start, delimiter.size(), delimiter) == 0) {
|
|
||||||
if (max_delimiter == nullptr || delimiter.size() > max_delimiter->size()) {
|
|
||||||
max_end = end;
|
|
||||||
max_delimiter = &delimiter;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (max_delimiter != nullptr) {
|
|
||||||
offsets.push_back(start);
|
|
||||||
offsets.push_back(max_end);
|
|
||||||
start = max_end;
|
|
||||||
} else {
|
|
||||||
start++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
offsets.push_back(text.size());
|
|
||||||
return offsets;
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
|
static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
|
||||||
llama_tokenizer tokenizer(vocab);
|
llama_tokenizer tokenizer(vocab);
|
||||||
std::vector<llama_vocab::id> output;
|
std::vector<llama_vocab::id> output;
|
||||||
|
@ -2137,27 +2100,40 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<size_t> offsets = llama_split_special_tokens(vocab, text);
|
size_t delim_start = 0;
|
||||||
size_t start = 0;
|
size_t last_delim_end = 0;
|
||||||
for (size_t end : offsets) {
|
|
||||||
if (start >= end) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
const char *part = text.c_str() + start;
|
while (delim_start < text.size()) {
|
||||||
size_t part_len = end - start;
|
size_t delim_end = 0;
|
||||||
if (vocab.max_special_token_length < part_len) {
|
llama_vocab::id token_id = -1;
|
||||||
tokenizer.tokenize(part, part_len, output);
|
|
||||||
} else {
|
for (const auto & mit : vocab.special_token_to_id) {
|
||||||
auto token_it = vocab.special_token_to_id.find(std::string(part, part_len));
|
const std::string & delimiter = mit.first;
|
||||||
if (token_it != vocab.special_token_to_id.end()) {
|
size_t end = delim_start + delimiter.size();
|
||||||
output.push_back(token_it->second);
|
if (end <= text.size() && text.compare(delim_start, delimiter.size(), delimiter) == 0) {
|
||||||
} else {
|
if (token_id == -1 || end > delim_end) {
|
||||||
tokenizer.tokenize(part, part_len, output);
|
token_id = mit.second;
|
||||||
|
delim_end = end;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
start = end;
|
|
||||||
|
if (token_id != -1) {
|
||||||
|
if (last_delim_end < delim_start) {
|
||||||
|
tokenizer.tokenize(text.c_str() + last_delim_end, delim_start - last_delim_end, output);
|
||||||
|
}
|
||||||
|
output.push_back(token_id);
|
||||||
|
delim_start = delim_end;
|
||||||
|
last_delim_end = delim_end;
|
||||||
|
} else {
|
||||||
|
delim_start++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (last_delim_end < text.size()) {
|
||||||
|
tokenizer.tokenize(text.c_str() + last_delim_end, text.size() - last_delim_end, output);
|
||||||
|
}
|
||||||
|
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue