Refactor special tokens tokenization
parent ada6cce40f
commit 465cadd44c
1 changed file with 30 additions and 54 deletions
llama.cpp: 84 lines changed
@@ -279,14 +279,9 @@ struct llama_vocab {
     std::vector<token_score> id_to_token;

     std::unordered_map<token, id> special_token_to_id;
-    size_t max_special_token_length = 0;

     void add_special_token(const token & word, id token_id) {
         special_token_to_id[word] = token_id;
-
-        if (max_special_token_length < word.size()) {
-            max_special_token_length = word.size();
-        }
     }
 };
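In this layout the special tokens live in a plain token-string to id map on the vocab, and with this change the vocab no longer tracks max_special_token_length, because the rewritten tokenizer path further down never needs that bound. Below is a minimal, self-contained sketch of the registration pattern; the toy_vocab struct, the token strings and the ids are invented for illustration and are not taken from this commit or any model file.

// Hypothetical sketch of the registration pattern kept by this commit:
// special tokens are a plain string -> id map, filled one entry at a time.
#include <cstdio>
#include <string>
#include <unordered_map>

struct toy_vocab {
    using token = std::string;
    using id    = int;

    std::unordered_map<token, id> special_token_to_id;

    void add_special_token(const token & word, id token_id) {
        special_token_to_id[word] = token_id;
    }
};

int main() {
    toy_vocab vocab;
    vocab.add_special_token("<FIM_PREFIX>", 1); // made-up token and id
    vocab.add_special_token("<FIM_SUFFIX>", 2); // made-up token and id

    // lookup works by exact string match against the registered tokens
    auto it = vocab.special_token_to_id.find("<FIM_PREFIX>");
    if (it != vocab.special_token_to_id.end()) {
        std::printf("%s -> %d\n", it->first.c_str(), it->second);
    }
    return 0;
}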
@@ -2088,38 +2083,6 @@ private:
     llama_sp_bigram::queue work_queue_;
 };

-static std::vector<size_t> llama_split_special_tokens(const llama_vocab & vocab, const std::string & text) {
-    std::vector<size_t> offsets{0};
-    size_t start = 0;
-
-    while (start < text.size()) {
-        size_t max_end = start;
-        const std::string * max_delimiter = nullptr;
-
-        for (const auto & mit : vocab.special_token_to_id) {
-            const std::string & delimiter = mit.first;
-            size_t end = start + delimiter.size();
-            if (end <= text.size() && text.compare(start, delimiter.size(), delimiter) == 0) {
-                if (max_delimiter == nullptr || delimiter.size() > max_delimiter->size()) {
-                    max_end = end;
-                    max_delimiter = &delimiter;
-                }
-            }
-        }
-
-        if (max_delimiter != nullptr) {
-            offsets.push_back(start);
-            offsets.push_back(max_end);
-            start = max_end;
-        } else {
-            start++;
-        }
-    }
-
-    offsets.push_back(text.size());
-    return offsets;
-}
-
 static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
     llama_tokenizer tokenizer(vocab);
     std::vector<llama_vocab::id> output;
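The deleted helper pre-scanned the whole input and returned a flat list of offsets whose consecutive pairs mark alternating plain and special-token spans, always preferring the longest special token matching at a position. Below is a standalone sketch of that behaviour; the split_special_tokens name, the <SEP> token and its id are made up for the example, and only the splitting logic mirrors the removed code.

// Illustration of what the removed llama_split_special_tokens computed.
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

static std::vector<size_t> split_special_tokens(
        const std::unordered_map<std::string, int> & special_token_to_id,
        const std::string & text) {
    std::vector<size_t> offsets{0};
    size_t start = 0;

    while (start < text.size()) {
        size_t max_end = start;
        const std::string * max_delimiter = nullptr;

        // the longest special token matching at `start` wins
        for (const auto & mit : special_token_to_id) {
            const std::string & delimiter = mit.first;
            size_t end = start + delimiter.size();
            if (end <= text.size() && text.compare(start, delimiter.size(), delimiter) == 0) {
                if (max_delimiter == nullptr || delimiter.size() > max_delimiter->size()) {
                    max_end = end;
                    max_delimiter = &delimiter;
                }
            }
        }

        if (max_delimiter != nullptr) {
            offsets.push_back(start);   // end of the plain span before the match
            offsets.push_back(max_end); // end of the special-token span
            start = max_end;
        } else {
            start++;
        }
    }

    offsets.push_back(text.size());
    return offsets;
}

int main() {
    std::unordered_map<std::string, int> special = {{"<SEP>", 1}}; // made-up token
    for (size_t off : split_special_tokens(special, "foo<SEP>bar")) {
        std::printf("%zu ", off); // prints: 0 3 8 11
    }
    std::printf("\n");
    return 0;
}

For "foo<SEP>bar" this prints 0 3 8 11. The old llama_tokenize (removed in the next hunk) then walked consecutive offsets, looked each segment up in special_token_to_id, and fell back to the SP tokenizer for everything else.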
@@ -2137,27 +2100,40 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
         return output;
     }

-    std::vector<size_t> offsets = llama_split_special_tokens(vocab, text);
-    size_t start = 0;
-    for (size_t end : offsets) {
-        if (start >= end) {
-            continue;
-        }
-
-        const char *part = text.c_str() + start;
-        size_t part_len = end - start;
-        if (vocab.max_special_token_length < part_len) {
-            tokenizer.tokenize(part, part_len, output);
-        } else {
-            auto token_it = vocab.special_token_to_id.find(std::string(part, part_len));
-            if (token_it != vocab.special_token_to_id.end()) {
-                output.push_back(token_it->second);
-            } else {
-                tokenizer.tokenize(part, part_len, output);
-            }
-        }
-        start = end;
-    }
+    size_t delim_start = 0;
+    size_t last_delim_end = 0;
+
+    while (delim_start < text.size()) {
+        size_t delim_end = 0;
+        llama_vocab::id token_id = -1;
+
+        for (const auto & mit : vocab.special_token_to_id) {
+            const std::string & delimiter = mit.first;
+            size_t end = delim_start + delimiter.size();
+            if (end <= text.size() && text.compare(delim_start, delimiter.size(), delimiter) == 0) {
+                if (token_id == -1 || end > delim_end) {
+                    token_id = mit.second;
+                    delim_end = end;
+                }
+            }
+        }
+
+        if (token_id != -1) {
+            if (last_delim_end < delim_start) {
+                tokenizer.tokenize(text.c_str() + last_delim_end, delim_start - last_delim_end, output);
+            }
+            output.push_back(token_id);
+            delim_start = delim_end;
+            last_delim_end = delim_end;
+        } else {
+            delim_start++;
+        }
+    }
+
+    if (last_delim_end < text.size()) {
+        tokenizer.tokenize(text.c_str() + last_delim_end, text.size() - last_delim_end, output);
+    }

     return output;
 }
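The replacement folds the matching into llama_tokenize itself: a single pass over the text that greedily takes the longest registered special token at each position, emits its id, and hands the ordinary spans in between to the regular tokenizer. Below is a self-contained sketch of that scan; tokenize_plain is a stand-in that only prints the span instead of running the SP tokenizer, and the special tokens and ids are invented for the example.

// Hypothetical demonstration of the one-pass, longest-match scan added here.
#include <cstdio>
#include <string>
#include <unordered_map>

static void tokenize_plain(const char * text, size_t len) {
    std::printf("plain   : \"%.*s\"\n", (int) len, text); // stand-in for the SP tokenizer
}

int main() {
    const std::unordered_map<std::string, int> special_token_to_id = {
        {"<SEP>", 1}, {"<SEP><SEP>", 2}, // made-up tokens; the longer match must win
    };
    const std::string text = "foo<SEP><SEP>bar";

    size_t delim_start    = 0; // current scan position
    size_t last_delim_end = 0; // end of the last special token emitted

    while (delim_start < text.size()) {
        size_t delim_end = 0;
        int token_id = -1;

        // pick the longest special token that matches at delim_start
        for (const auto & mit : special_token_to_id) {
            const std::string & delimiter = mit.first;
            size_t end = delim_start + delimiter.size();
            if (end <= text.size() && text.compare(delim_start, delimiter.size(), delimiter) == 0) {
                if (token_id == -1 || end > delim_end) {
                    token_id  = mit.second;
                    delim_end = end;
                }
            }
        }

        if (token_id != -1) {
            // flush the ordinary text before the match, then emit the special id
            if (last_delim_end < delim_start) {
                tokenize_plain(text.c_str() + last_delim_end, delim_start - last_delim_end);
            }
            std::printf("special : id %d\n", token_id);
            delim_start    = delim_end;
            last_delim_end = delim_end;
        } else {
            delim_start++;
        }
    }

    // flush whatever ordinary text remains after the last special token
    if (last_delim_end < text.size()) {
        tokenize_plain(text.c_str() + last_delim_end, text.size() - last_delim_end);
    }
    return 0;
}

Compiled and run, this prints the plain spans "foo" and "bar" around the id of the longer <SEP><SEP> token, which shows the longest-match rule of the inner loop. Compared with the removed path, there is no separate offsets pass, no per-segment map lookup, and no need for the max_special_token_length bound.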