diff --git a/llama.cpp b/llama.cpp
index 3b6d23eac..44104be66 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -279,14 +279,9 @@ struct llama_vocab {
 
     std::vector<token_score> id_to_token;
     std::unordered_map<token, id> special_token_to_id;
-    size_t max_special_token_length = 0;
 
     void add_special_token(const token & word, id token_id) {
         special_token_to_id[word] = token_id;
-
-        if (max_special_token_length < word.size()) {
-            max_special_token_length = word.size();
-        }
     }
 };
 
@@ -2088,38 +2083,6 @@ private:
     llama_sp_bigram::queue work_queue_;
 };
 
-static std::vector<size_t> llama_split_special_tokens(const llama_vocab & vocab, const std::string & text) {
-    std::vector<size_t> offsets{0};
-    size_t start = 0;
-
-    while (start < text.size()) {
-        size_t max_end = start;
-        const std::string * max_delimiter = nullptr;
-
-        for (const auto & mit : vocab.special_token_to_id) {
-            const std::string & delimiter = mit.first;
-            size_t end = start + delimiter.size();
-            if (end <= text.size() && text.compare(start, delimiter.size(), delimiter) == 0) {
-                if (max_delimiter == nullptr || delimiter.size() > max_delimiter->size()) {
-                    max_end = end;
-                    max_delimiter = &delimiter;
-                }
-            }
-        }
-
-        if (max_delimiter != nullptr) {
-            offsets.push_back(start);
-            offsets.push_back(max_end);
-            start = max_end;
-        } else {
-            start++;
-        }
-    }
-
-    offsets.push_back(text.size());
-    return offsets;
-}
-
 static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
     llama_tokenizer tokenizer(vocab);
     std::vector<llama_vocab::id> output;
@@ -2137,27 +2100,40 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
         return output;
     }
 
-    std::vector<size_t> offsets = llama_split_special_tokens(vocab, text);
-    size_t start = 0;
-    for (size_t end : offsets) {
-        if (start >= end) {
-            continue;
-        }
+    size_t delim_start = 0;
+    size_t last_delim_end = 0;
 
-        const char *part = text.c_str() + start;
-        size_t part_len = end - start;
-        if (vocab.max_special_token_length < part_len) {
-            tokenizer.tokenize(part, part_len, output);
-        } else {
-            auto token_it = vocab.special_token_to_id.find(std::string(part, part_len));
-            if (token_it != vocab.special_token_to_id.end()) {
-                output.push_back(token_it->second);
-            } else {
-                tokenizer.tokenize(part, part_len, output);
+    while (delim_start < text.size()) {
+        size_t delim_end = 0;
+        llama_vocab::id token_id = -1;
+
+        for (const auto & mit : vocab.special_token_to_id) {
+            const std::string & delimiter = mit.first;
+            size_t end = delim_start + delimiter.size();
+            if (end <= text.size() && text.compare(delim_start, delimiter.size(), delimiter) == 0) {
+                if (token_id == -1 || end > delim_end) {
+                    token_id = mit.second;
+                    delim_end = end;
+                }
             }
         }
-        start = end;
+
+        if (token_id != -1) {
+            if (last_delim_end < delim_start) {
+                tokenizer.tokenize(text.c_str() + last_delim_end, delim_start - last_delim_end, output);
+            }
+            output.push_back(token_id);
+            delim_start = delim_end;
+            last_delim_end = delim_end;
+        } else {
+            delim_start++;
+        }
     }
+
+    if (last_delim_end < text.size()) {
+        tokenizer.tokenize(text.c_str() + last_delim_end, text.size() - last_delim_end, output);
+    }
+
     return output;
 }
 
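For reference, the loop this patch adds to llama_tokenize can be read in isolation as follows. This is a minimal standalone sketch of the same longest-match scan, not the patched function itself: special_tokens stands in for vocab.special_token_to_id, and tokenize_text is a hypothetical stub for the SentencePiece fallback that llama_tokenizer::tokenize provides in the real code.

// Sketch of the longest-match scan added above, with illustrative names.
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

using id = int32_t;

// Hypothetical stub for the SentencePiece fallback; the real code runs
// llama_tokenizer::tokenize on the span. Here: one placeholder id per span.
static void tokenize_text(const std::string & span, std::vector<id> & out) {
    (void) span;
    out.push_back(0);
}

static std::vector<id> tokenize_with_special(
        const std::unordered_map<std::string, id> & special_tokens,
        const std::string & text) {
    std::vector<id> output;
    size_t delim_start = 0;    // cursor scanning the text
    size_t last_delim_end = 0; // end of the last special token emitted

    while (delim_start < text.size()) {
        size_t delim_end = 0;
        id token_id = -1;

        // Of all special tokens that start at delim_start, keep the longest.
        for (const auto & st : special_tokens) {
            const std::string & delim = st.first;
            size_t end = delim_start + delim.size();
            if (end <= text.size() && text.compare(delim_start, delim.size(), delim) == 0) {
                if (token_id == -1 || end > delim_end) {
                    token_id = st.second;
                    delim_end = end;
                }
            }
        }

        if (token_id != -1) {
            // Flush ordinary text preceding the match, then the special id.
            if (last_delim_end < delim_start) {
                tokenize_text(text.substr(last_delim_end, delim_start - last_delim_end), output);
            }
            output.push_back(token_id);
            delim_start = delim_end;
            last_delim_end = delim_end;
        } else {
            delim_start++;
        }
    }

    // Flush any trailing ordinary text.
    if (last_delim_end < text.size()) {
        tokenize_text(text.substr(last_delim_end), output);
    }
    return output;
}

int main() {
    // "<t>" and "<t>>" both match at index 1 of "a<t>>x"; the longer wins.
    std::unordered_map<std::string, id> special_tokens = { {"<t>", 1}, {"<t>>", 2} };
    for (id t : tokenize_with_special(special_tokens, "a<t>>x")) {
        std::cout << t << ' '; // prints: 0 2 0
    }
    std::cout << '\n';
}

Compared with the removed llama_split_special_tokens, this single pass avoids materializing an offsets vector and no longer needs max_special_token_length, since the longest match at each position is selected directly while scanning.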