Refactor special tokens tokenization

Igor Pissolati 2023-08-08 12:46:18 -03:00
parent ada6cce40f
commit 465cadd44c


@@ -279,14 +279,9 @@ struct llama_vocab {
     std::vector<token_score> id_to_token;
 
     std::unordered_map<token, id> special_token_to_id;
-    size_t max_special_token_length = 0;
 
     void add_special_token(const token & word, id token_id) {
         special_token_to_id[word] = token_id;
-
-        if (max_special_token_length < word.size()) {
-            max_special_token_length = word.size();
-        }
     }
 };
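
With the length bookkeeping gone, registering a special token reduces to a plain map insert. A minimal standalone sketch, assuming add_special_token keeps the reduced form shown above; the token string and id below are placeholders, not values defined by this commit:

    #include <string>
    #include <unordered_map>

    // Reduced shape of the vocab after this hunk: only the map is needed,
    // since the tokenizer no longer consults a max special-token length.
    struct vocab_sketch {
        using id    = int;
        using token = std::string;

        std::unordered_map<token, id> special_token_to_id;

        void add_special_token(const token & word, id token_id) {
            special_token_to_id[word] = token_id;
        }
    };

    int main() {
        vocab_sketch vocab;
        vocab.add_special_token("<unk_special>", 32000); // hypothetical token/id
        return 0;
    }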
@@ -2088,38 +2083,6 @@ private:
     llama_sp_bigram::queue work_queue_;
 };
 
-static std::vector<size_t> llama_split_special_tokens(const llama_vocab & vocab, const std::string & text) {
-    std::vector<size_t> offsets{0};
-    size_t start = 0;
-
-    while (start < text.size()) {
-        size_t max_end = start;
-        const std::string * max_delimiter = nullptr;
-
-        for (const auto & mit : vocab.special_token_to_id) {
-            const std::string & delimiter = mit.first;
-            size_t end = start + delimiter.size();
-            if (end <= text.size() && text.compare(start, delimiter.size(), delimiter) == 0) {
-                if (max_delimiter == nullptr || delimiter.size() > max_delimiter->size()) {
-                    max_end = end;
-                    max_delimiter = &delimiter;
-                }
-            }
-        }
-
-        if (max_delimiter != nullptr) {
-            offsets.push_back(start);
-            offsets.push_back(max_end);
-            start = max_end;
-        } else {
-            start++;
-        }
-    }
-
-    offsets.push_back(text.size());
-    return offsets;
-}
-
 static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
     llama_tokenizer tokenizer(vocab);
     std::vector<llama_vocab::id> output;
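
For reference, the helper removed above returned a flat list of boundary offsets that llama_tokenize then had to re-pair with a running start index, looking each special-token segment up in the map a second time. A standalone sketch of that behaviour, using a stand-in vocab struct and a made-up <SEP> token:

    #include <cstdio>
    #include <string>
    #include <unordered_map>
    #include <vector>

    struct toy_vocab {
        std::unordered_map<std::string, int> special_token_to_id;
    };

    // Same logic as the removed llama_split_special_tokens, specialized to toy_vocab.
    static std::vector<size_t> split_special_tokens(const toy_vocab & vocab, const std::string & text) {
        std::vector<size_t> offsets{0};
        size_t start = 0;
        while (start < text.size()) {
            size_t max_end = start;
            const std::string * max_delimiter = nullptr;
            for (const auto & mit : vocab.special_token_to_id) {
                const std::string & delimiter = mit.first;
                size_t end = start + delimiter.size();
                if (end <= text.size() && text.compare(start, delimiter.size(), delimiter) == 0) {
                    if (max_delimiter == nullptr || delimiter.size() > max_delimiter->size()) {
                        max_end = end;
                        max_delimiter = &delimiter;
                    }
                }
            }
            if (max_delimiter != nullptr) {
                // Record both edges of the special-token span.
                offsets.push_back(start);
                offsets.push_back(max_end);
                start = max_end;
            } else {
                start++;
            }
        }
        offsets.push_back(text.size());
        return offsets;
    }

    int main() {
        toy_vocab vocab;
        vocab.special_token_to_id["<SEP>"] = 1; // hypothetical special token
        for (size_t off : split_special_tokens(vocab, "foo<SEP>bar")) {
            printf("%zu ", off); // prints: 0 3 8 11
        }
        printf("\n");
        return 0;
    }

Inlining this matching into llama_tokenize (next hunk) drops the intermediate offsets vector and the second map lookup per special-token segment.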
@@ -2137,27 +2100,40 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
         return output;
     }
 
-    std::vector<size_t> offsets = llama_split_special_tokens(vocab, text);
-    size_t start = 0;
-    for (size_t end : offsets) {
-        if (start >= end) {
-            continue;
-        }
+    size_t delim_start = 0;
+    size_t last_delim_end = 0;
 
-        const char *part = text.c_str() + start;
-        size_t part_len = end - start;
-        if (vocab.max_special_token_length < part_len) {
-            tokenizer.tokenize(part, part_len, output);
-        } else {
-            auto token_it = vocab.special_token_to_id.find(std::string(part, part_len));
-            if (token_it != vocab.special_token_to_id.end()) {
-                output.push_back(token_it->second);
-            } else {
-                tokenizer.tokenize(part, part_len, output);
+    while (delim_start < text.size()) {
+        size_t delim_end = 0;
+        llama_vocab::id token_id = -1;
+
+        for (const auto & mit : vocab.special_token_to_id) {
+            const std::string & delimiter = mit.first;
+            size_t end = delim_start + delimiter.size();
+            if (end <= text.size() && text.compare(delim_start, delimiter.size(), delimiter) == 0) {
+                if (token_id == -1 || end > delim_end) {
+                    token_id = mit.second;
+                    delim_end = end;
+                }
             }
         }
-        start = end;
+
+        if (token_id != -1) {
+            if (last_delim_end < delim_start) {
+                tokenizer.tokenize(text.c_str() + last_delim_end, delim_start - last_delim_end, output);
+            }
+            output.push_back(token_id);
+            delim_start = delim_end;
+            last_delim_end = delim_end;
+        } else {
+            delim_start++;
+        }
     }
+
+    if (last_delim_end < text.size()) {
+        tokenizer.tokenize(text.c_str() + last_delim_end, text.size() - last_delim_end, output);
+    }
 
     return output;
 }
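
The new loop matches special tokens greedily, picking the longest match at each position, and only hands the plain-text spans between them to the SentencePiece tokenizer. A minimal standalone sketch with the tokenizer replaced by a substring-recording stub; the token strings and ids are made up:

    #include <cstdio>
    #include <string>
    #include <unordered_map>
    #include <vector>

    int main() {
        std::unordered_map<std::string, int> special_token_to_id = {
            {"<SEP>",      1}, // hypothetical ids
            {"<SEP><SEP>", 2},
        };
        const std::string text = "foo<SEP><SEP>bar";

        std::vector<std::string> pieces; // stands in for tokenizer.tokenize(...) output

        size_t delim_start    = 0;
        size_t last_delim_end = 0;
        while (delim_start < text.size()) {
            size_t delim_end = 0;
            int token_id = -1;
            // Pick the longest special token that matches at delim_start.
            for (const auto & mit : special_token_to_id) {
                const std::string & delimiter = mit.first;
                size_t end = delim_start + delimiter.size();
                if (end <= text.size() && text.compare(delim_start, delimiter.size(), delimiter) == 0) {
                    if (token_id == -1 || end > delim_end) {
                        token_id  = mit.second;
                        delim_end = end;
                    }
                }
            }
            if (token_id != -1) {
                // Flush the plain text accumulated since the previous special token.
                if (last_delim_end < delim_start) {
                    pieces.push_back(text.substr(last_delim_end, delim_start - last_delim_end));
                }
                pieces.push_back("[id " + std::to_string(token_id) + "]");
                delim_start    = delim_end;
                last_delim_end = delim_end;
            } else {
                delim_start++;
            }
        }
        if (last_delim_end < text.size()) {
            pieces.push_back(text.substr(last_delim_end));
        }

        for (const auto & p : pieces) {
            printf("%s\n", p.c_str()); // prints: foo, [id 2], bar
        }
        return 0;
    }

Because the longest match wins, "<SEP><SEP>" is emitted as a single id rather than two, and the text around it reaches the tokenizer untouched.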