From 61a98bc30aca7d7a839539577be4931250c881e3 Mon Sep 17 00:00:00 2001
From: Igor Pissolati
Date: Sun, 18 Jun 2023 20:11:01 -0300
Subject: [PATCH] Improve support for special tokens

---
 convert.py   |  89 +++++++++++++++++++---------
 llama-util.h | 164 +++++++++++++++++++++++++++++++++++++++++++++++++++
 llama.cpp    |  76 +++++++++++++++++++++---
 llama.h      |   4 +-
 4 files changed, 297 insertions(+), 36 deletions(-)

diff --git a/convert.py b/convert.py
index f3bf17980..8bc06120d 100644
--- a/convert.py
+++ b/convert.py
@@ -142,6 +142,7 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
 @dataclass
 class Params:
     n_vocab: int
+    n_vocab_sp: int
     n_embd: int
     n_mult: int
     n_head: int
@@ -169,6 +170,7 @@ class Params:
 
         return Params(
             n_vocab = n_vocab,
+            n_vocab_sp = n_vocab,
             n_embd  = n_embd,
             n_mult  = 256,
             n_head  = n_head,
@@ -191,6 +193,7 @@ class Params:
 
         return Params(
             n_vocab = n_vocab,
+            n_vocab_sp = n_vocab,
             n_embd  = n_embd,
             n_mult  = n_mult,
             n_head  = n_head,
@@ -215,6 +218,7 @@ class Params:
 
         return Params(
             n_vocab = n_vocab,
+            n_vocab_sp = n_vocab,
             n_embd  = n_embd,
             n_mult  = n_mult,
             n_head  = n_head,
@@ -239,7 +243,7 @@ class Params:
 
 
 class SentencePieceVocab:
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vocabtype: Optional[str]) -> None:
+    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fname_special_tokens: Optional[Path], vocabtype: Optional[str]) -> None:
         self.vocabtype = vocabtype
         if self.vocabtype == "bpe":
             self.sentencepiece_tokenizer = json.loads(open(str(fname_tokenizer)).read())
@@ -264,35 +268,46 @@ class SentencePieceVocab:
         self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
         self.fname_tokenizer = fname_tokenizer
         self.fname_added_tokens = fname_added_tokens
+        special_tokens: Dict[str, Dict[str, Any]]
+        if fname_special_tokens is not None:
+            special_tokens = json.load(open(fname_special_tokens))
+        else:
+            special_tokens = {}
+        token_name_to_id = {"unk_token": self.sentencepiece_tokenizer.unk_id(), "bos_token": self.sentencepiece_tokenizer.bos_id(), "eos_token": self.sentencepiece_tokenizer.eos_id(), "pad_token": self.sentencepiece_tokenizer.pad_id()}
+        self.special_tokens_map = {token_name_to_id[token_name]: info["content"] if isinstance(info, dict) else info for token_name, info in special_tokens.items() if token_name in token_name_to_id and token_name_to_id[token_name] != -1}
+        self.vocab_special_size: int = len(self.added_tokens_list) + len(self.special_tokens_map)
 
     def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]:
         tokenizer = self.sentencepiece_tokenizer
         if self.vocabtype == "bpe":
-          from transformers.models.gpt2 import tokenization_gpt2
-          byte_encoder = tokenization_gpt2.bytes_to_unicode()
-          byte_decoder = {v: k for k, v in byte_encoder.items()}
-          for i, item in enumerate(tokenizer):
-            text: bytes
-            text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]])
-            score: float = -i
-            yield text, score
+            from transformers.models.gpt2 import tokenization_gpt2
+            byte_encoder = tokenization_gpt2.bytes_to_unicode()
+            byte_decoder = {v: k for k, v in byte_encoder.items()}
+            for i, item in enumerate(tokenizer):
+                text: bytes
+                text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]])
+                score: float = -i
+                yield text, score
         else:
-          for i in range(tokenizer.vocab_size()):
-            text: bytes
-            if tokenizer.is_unknown(i):
-              text = " \u2047 ".encode("utf-8")
-            elif tokenizer.is_control(i):
-              text = b""
-            elif tokenizer.is_byte(i):
-              piece = tokenizer.id_to_piece(i)
-              if len(piece) != 6:
-                raise Exception(f"Invalid token: {piece}")
-              byte_value = int(piece[3:-1], 16)
-              text = struct.pack("B", byte_value)
-            else:
-              text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
-            score: float = tokenizer.get_score(i)
-            yield text, score
+            special_tokens = [tokenizer.bos_id(), tokenizer.eos_id(), tokenizer.pad_id()]
+            for i in range(tokenizer.vocab_size()):
+                text: bytes
+                if tokenizer.is_unknown(i):
+                    text = self.special_tokens_map.get(i, " \u2047 ").encode("utf-8")
+                elif i in special_tokens:
+                    text = self.special_tokens_map.get(i, "").encode("utf-8")
+                elif tokenizer.is_control(i):
+                    text = b""
+                elif tokenizer.is_byte(i):
+                    piece = tokenizer.id_to_piece(i)
+                    if len(piece) != 6:
+                        raise Exception(f"Invalid token: {piece}")
+                    byte_value = int(piece[3:-1], 16)
+                    text = struct.pack("B", byte_value)
+                else:
+                    text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
+                score: float = tokenizer.get_score(i)
+                yield text, score
 
     def added_tokens(self) -> Iterable[Tuple[bytes, float]]:
         for text in self.added_tokens_list:
@@ -303,6 +318,12 @@ class SentencePieceVocab:
         yield from self.sentencepiece_tokens()
         yield from self.added_tokens()
 
+    def all_special_tokens(self) -> Iterable[int]:
+        for token_id in self.special_tokens_map.keys():
+            yield token_id
+        for i in range(len(self.added_tokens_list)):
+            yield self.vocab_size_base + i
+
     def __repr__(self) -> str:
         return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
@@ -310,11 +331,16 @@ class SentencePieceVocab:
 class GGMLVocab:
     def __init__(self, tokens: List[Tuple[bytes, float]]):
         self.tokens = tokens
+        self.special_tokens = []
         self.vocab_size = len(tokens)
+        self.vocab_special_size = 0
 
     def all_tokens(self) -> Iterable[Tuple[bytes, float]]:
         return self.tokens
 
+    def all_special_tokens(self) -> Iterable[int]:
+        return self.special_tokens
+
     def __repr__(self) -> str:
         return f"<GGMLVocab with {self.vocab_size} tokens>"
 
@@ -1066,8 +1092,9 @@ class OutputFile:
     def write_file_header(self, params: Params, file_type: GGMLFileType) -> None:
         self.fout.write(b"ggjt"[::-1])  # magic
         values = [
-            1,  # file version
+            4,  # file version
             params.n_vocab,
+            params.n_vocab_sp,
             params.n_embd,
             params.n_mult,
             params.n_head,
@@ -1089,11 +1116,14 @@ class OutputFile:
             self.fout.write(struct.pack("i", len(text)))
             self.fout.write(text)
             self.fout.write(struct.pack("f", score))
+        for token_id in vocab.all_special_tokens():
+            self.fout.write(struct.pack("i", token_id))
 
     @staticmethod
     def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
         of = OutputFile(fname_out)
-        params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
+        params = Params(n_vocab=vocab.vocab_size, n_vocab_sp=vocab.vocab_special_size, n_embd=0, n_mult=0,
+                        n_head=1, n_layer=0)
         of = OutputFile(fname_out)
         of.write_file_header(params, file_type=GGMLFileType.AllF32)
         of.write_vocab(vocab)
@@ -1249,8 +1279,10 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab:
             f"Could not find tokenizer.model in {path} or its parent; "
             "if it's in another directory, pass the directory as --vocab-dir")
     added_tokens_path = path.parent / "added_tokens.json"
+    special_tokens_path = path.parent / "special_tokens_map.json"
+    tokenizer_config_path = path.parent / "tokenizer_config.json"
     print(f"Loading vocab file {path}")
-    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None,
+    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None,
+                              special_tokens_path if special_tokens_path.exists() else tokenizer_config_path if tokenizer_config_path.exists() else None,
                               vocabtype)
@@ -1313,6 +1345,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
         vocab = load_vocab(vocab_dir, args.vocabtype)
         params = Params.load(model_plus)
+        params.n_vocab_sp = vocab.vocab_special_size
         model = model_plus.model
         model = do_necessary_conversions(model, params)
         output_type = pick_output_type(model, args.outtype)
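
The resulting GGJT v4 layout, as written by write_file_header() and write_vocab() above, can be summarized as follows. This is an illustrative sketch of the byte layout, not code from the patch:

    // GGJT v4 file layout (sketch; all integers little-endian):
    //
    //   uint32  magic       'ggjt' (written byte-reversed, as above)
    //   uint32  version     4          (convert.py previously wrote 1)
    //   uint32  n_vocab
    //   uint32  n_vocab_sp                         <- new hparam
    //   uint32  n_embd, n_mult, n_head, n_layer, ... (remaining hparams unchanged)
    //   n_vocab    x { int32 len; char text[len]; float score; }   token table
    //   n_vocab_sp x { int32 token_id; }           <- new section: ids of special/added tokens
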
diff --git a/llama-util.h b/llama-util.h
index 6e9e39ddb..a3ec8a501 100644
--- a/llama-util.h
+++ b/llama-util.h
@@ -14,6 +14,8 @@
 #include <string>
 #include <vector>
+#include <map>
+#include <unordered_map>
 #include <stdexcept>
 
 #ifdef __has_include
@@ -541,4 +543,166 @@ struct llama_ctx_buffer {
 typedef llama_buffer llama_ctx_buffer;
 #endif
 
+struct llama_trie_node {
+    llama_trie_node(): is_terminator(false) {}
+
+    std::unordered_map<char, llama_trie_node*> children;
+    bool is_terminator;
+};
+
+// Trie in C++. Creates a Trie out of a list of words. The trie is used to split on multiple delimiters in one pass
+// Ported from: https://github.com/huggingface/transformers/blob/ee88ae59940fd4b2c8fc119373143d7a1175c651/src/transformers/tokenization_utils.py#L52
+struct llama_trie {
+public:
+    llama_trie(): root_(new llama_trie_node()) {}
+
+    void add(const std::string & word) {
+        if (word.empty()) {
+            return;
+        }
+
+        llama_trie_node *ref = root_;
+        for (char c : word) {
+            if (ref->children.find(c) == ref->children.end()) {
+                ref->children[c] = new llama_trie_node();
+            }
+            ref = ref->children[c];
+        }
+        ref->is_terminator = true;
+    }
+
+    // Will look for the words added to the trie within `text`. Output is the boundaries of the words found.
+    // Note that this trie will match the longest possible word first!
+    std::vector<int> split(const std::string & text) const {
+        std::map<int, llama_trie_node*> states;
+        std::vector<int> offsets{0};
+
+        int skip = 0;
+        for (int current = 0; current < (int) text.size(); current++) {
+            char current_char = text[current];
+            if (skip > 0 && current < skip) {
+                // Prevents the lookahead from matching twice,
+                // like extra_id_100 and id_100
+                continue;
+            }
+
+            // Whenever we find a match, we need to drop everything;
+            // this is a greedy algorithm, it will match on the first found token
+            bool reset = false;
+
+            // In this case, we already have partial matches (but unfinished)
+            for (auto state = states.begin(); state != states.end(); ) {
+                int start = state->first;
+                llama_trie_node *trie_pointer = state->second;
+                if (trie_pointer->is_terminator) {
+                    // This is a final match, we need to reset and
+                    // store the results in `offsets`.
+
+                    // Lookahead to match longest first
+                    // Important in case of extra_id_1 vs extra_id_100
+                    // Here we are also actively looking for other earlier partial
+                    // matches
+                    // "[CLS]", "L", we need to match CLS even if L is special
+                    int end = 0;
+                    for (const auto & look : states) {
+                        int lookstart = look.first;
+                        llama_trie_node *looktrie_pointer = look.second;
+                        int lookahead_index = 0;
+                        if (lookstart > start) {
+                            // This partial match is later, we can stop looking
+                            break;
+                        }
+                        if (lookstart < start) {
+                            // This partial match is earlier, the trie pointer
+                            // was already updated, so index is + 1
+                            lookahead_index = current + 1;
+                            end = current + 1;
+                        } else {
+                            // Here lookstart == start and
+                            //      looktrie_pointer == trie_pointer
+                            // It wasn't updated yet so indices are current ones
+                            lookahead_index = current;
+                            end = current;
+                        }
+                        char next_char = lookahead_index < (int) text.size() ? text[lookahead_index] : '\0';
+                        if (looktrie_pointer->is_terminator) {
+                            start = lookstart;
+                            end = lookahead_index;
+                            skip = lookahead_index;
+                        }
+
+                        auto looktrie_pointer_it = looktrie_pointer->children.find(next_char);
+                        while (looktrie_pointer_it != looktrie_pointer->children.end()) {
+                            looktrie_pointer = looktrie_pointer_it->second;
+                            lookahead_index++;
+                            if (looktrie_pointer->is_terminator) {
+                                start = lookstart;
+                                end = lookahead_index;
+                                skip = lookahead_index;
+                            }
+
+                            if (lookahead_index == (int) text.size()) {
+                                // End of string
+                                break;
+                            }
+                            next_char = text[lookahead_index];
+                            looktrie_pointer_it = looktrie_pointer->children.find(next_char);
+                        }
+                    }
+
+                    offsets.push_back(start);
+                    offsets.push_back(end);
+                    reset = true;
+                    break;
+                }
+
+                auto trie_pointer_it = trie_pointer->children.find(current_char);
+                if (trie_pointer_it != trie_pointer->children.end()) {
+                    // The current character being looked at has a match within the trie;
+                    // update the pointer (it will be stored back into states later).
+                    trie_pointer = trie_pointer_it->second;
+                    states[start] = trie_pointer;
+                    ++state;
+                } else {
+                    // The new character has no match in the trie, we need
+                    // to stop keeping track of this partial match.
+                    state = states.erase(state);
+                }
+            }
+
+            if (reset) {
+                // Clear the full start (we found a real match)
+                states.clear();
+            }
+
+            // If this character is a starting character within the trie,
+            // start keeping track of this partial match.
+            auto children_it = root_->children.find(current_char);
+            if (current >= skip && children_it != root_->children.end()) {
+                states[current] = children_it->second;
+            }
+        }
+
+        // We have a cut at the end with states.
+        for (const auto & state : states) {
+            int start = state.first;
+            llama_trie_node *trie_pointer = state.second;
+            if (trie_pointer->is_terminator) {
+                // This is a final match, we need to reset and
+                // store the results in `offsets`.
+                int end = (int) text.size();
+                offsets.push_back(start);
+                offsets.push_back(end);
+                break;
+            }
+        }
+
+        offsets.push_back((int) text.size());
+        return offsets;
+    }
+
+private:
+    llama_trie_node *root_;
+};
+
 #endif
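
A minimal usage sketch of the trie (the words and expected output are illustrative): split() returns a flat, sorted list of cut points, and consecutive cut points bound each segment, so a caller walks them pairwise, exactly as llama_tokenize does later in this patch.

    #include <cstdio>
    #include <string>
    #include <vector>
    #include "llama-util.h"

    int main() {
        llama_trie trie;
        trie.add("[CLS]");
        trie.add("[SEP]");

        const std::string text = "[CLS] A document [SEP]";
        std::vector<int> offsets = trie.split(text);  // {0, 0, 5, 17, 22, 22} for this input

        int start = 0;
        for (int end : offsets) {
            if (start < end) {
                printf("segment: \"%s\"\n", text.substr(start, end - start).c_str());
            }
            start = end;
        }
        // prints the segments "[CLS]", " A document ", "[SEP]"
        return 0;
    }
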
diff --git a/llama.cpp b/llama.cpp
index 39aefd499..9908065ee 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -181,6 +181,7 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
 // default hparams (LLaMA 7B)
 struct llama_hparams {
     uint32_t n_vocab = 32000;
+    uint32_t n_vocab_sp = 0;
     uint32_t n_ctx   = 512;   // this is provided as user input?
     uint32_t n_embd  = 4096;
     uint32_t n_mult  = 256;
@@ -277,6 +278,11 @@ struct llama_vocab {
 
     std::unordered_map<token, id> token_to_id;
     std::vector<token_score> id_to_token;
+
+    llama_trie special_token_trie;
+    std::unordered_map<token, id> special_token_to_id;
+    std::vector<id> special_tokens;
+    size_t max_special_token_length = 0;
 };
 
 struct llama_model {
@@ -494,6 +500,7 @@ enum llama_file_version {
     LLAMA_FILE_VERSION_GGJT_V1, // added padding
     LLAMA_FILE_VERSION_GGJT_V2, // changed quantization format
     LLAMA_FILE_VERSION_GGJT_V3, // changed Q4 and Q8 quantization format
+    LLAMA_FILE_VERSION_GGJT_V4, // improved support for added/special tokens
 };
 
 struct llama_file_loader {
@@ -531,6 +538,7 @@ struct llama_file_loader {
             case 1: file_version = LLAMA_FILE_VERSION_GGJT_V1; return;
             case 2: file_version = LLAMA_FILE_VERSION_GGJT_V2; return;
             case 3: file_version = LLAMA_FILE_VERSION_GGJT_V3; return;
+            case 4: file_version = LLAMA_FILE_VERSION_GGJT_V4; return;
             }
         }
@@ -539,6 +547,7 @@ struct llama_file_loader {
     }
     void read_hparams() {
         hparams.n_vocab = file.read_u32();
+        hparams.n_vocab_sp = file_version >= LLAMA_FILE_VERSION_GGJT_V4 ? file.read_u32() : 0;
         hparams.n_embd = file.read_u32();
         hparams.n_mult = file.read_u32();
         hparams.n_head = file.read_u32();
@@ -566,6 +575,21 @@ struct llama_file_loader {
             tok_score.tok = std::move(word);
             tok_score.score = score;
         }
+
+        vocab.special_token_to_id.reserve(hparams.n_vocab_sp);
+
+        for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) {
+            uint32_t token_id = file.read_u32();
+            const auto & token = vocab.id_to_token[token_id].tok;
+
+            vocab.special_token_trie.add(token);
+            vocab.special_tokens.push_back(token_id);
+            vocab.special_token_to_id[token] = token_id;
+
+            if (vocab.max_special_token_length < token.size()) {
+                vocab.max_special_token_length = token.size();
+            }
+        }
     }
     void read_tensor_metadata(llama_load_tensors_map & tensors_map) {
         while (file.tell() < file.size) {
@@ -631,6 +655,7 @@ struct llama_file_saver {
     void write_hparams(enum llama_ftype new_ftype) {
         const llama_hparams & hparams = any_file_loader->hparams;
         file.write_u32(hparams.n_vocab);
+        file.write_u32(hparams.n_vocab_sp);
         file.write_u32(hparams.n_embd);
         file.write_u32(hparams.n_mult);
         file.write_u32(hparams.n_head);
@@ -649,6 +674,10 @@ struct llama_file_saver {
             file.write_raw(token_score.tok.data(), token_score.tok.size());
             file.write_raw(&token_score.score, sizeof(token_score.score));
         }
+        uint32_t n_vocab_sp = any_file_loader->hparams.n_vocab_sp;
+        for (uint32_t i = 0; i < n_vocab_sp; i++) {
+            file.write_u32(any_file_loader->vocab.special_tokens[i]);
+        }
     }
     void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) {
         switch (new_type) {
@@ -975,7 +1004,8 @@ static const char *llama_file_version_name(llama_file_version version) {
         case LLAMA_FILE_VERSION_GGMF_V1: return "ggmf v1 (old version with no mmap support)";
         case LLAMA_FILE_VERSION_GGJT_V1: return "ggjt v1 (pre #1405)";
         case LLAMA_FILE_VERSION_GGJT_V2: return "ggjt v2 (pre #1508)";
-        case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (latest)";
+        case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (pre #1931)";
+        case LLAMA_FILE_VERSION_GGJT_V4: return "ggjt v4 (latest)";
     }
 
     return "unknown";
@@ -1960,18 +1990,20 @@ struct llama_sp_bigram {
 struct llama_tokenizer {
     llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+    void tokenize(const char * text, size_t len, std::vector<llama_vocab::id> & output) {
+        symbols_.clear();
+
         // split string into utf8 chars
         int index = 0;
        size_t offs
 = 0;
-        while (offs < text.size()) {
+        while (offs < len) {
             llama_sp_symbol sym;
-            size_t char_len = std::min(text.size() - offs, utf8_len(text[offs]));
-            sym.text = text.c_str() + offs;
+            size_t char_len = std::min(len - offs, utf8_len(text[offs]));
+            sym.text = text + offs;
             sym.n = char_len;
             offs += char_len;
             sym.prev = index - 1;
-            sym.next = offs == text.size() ? -1 : index + 1;
+            sym.next = offs == len ? -1 : index + 1;
             index++;
             symbols_.emplace_back(sym);
         }
@@ -2074,7 +2106,33 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
         output.push_back(llama_token_bos());
     }
 
-    tokenizer.tokenize(text, output);
+    if (vocab.special_token_to_id.empty()) {
+        tokenizer.tokenize(text.c_str(), text.size(), output);
+        return output;
+    }
+
+    auto offsets = vocab.special_token_trie.split(text);
+    int start = 0;
+    for (int end : offsets) {
+        if (start >= end) {
+            continue;
+        }
+
+        size_t part_length = end - start;
+        //printf("\"%.*s\"\n", (int) part_length, text.c_str() + start);
+
+        if (vocab.max_special_token_length < part_length) {
+            tokenizer.tokenize(text.c_str() + start, part_length, output);
+        } else {
+            auto token_it = vocab.special_token_to_id.find(std::string(text.c_str() + start, part_length));
+            if (token_it != vocab.special_token_to_id.end()) {
+                output.push_back(token_it->second);
+            } else {
+                tokenizer.tokenize(text.c_str() + start, part_length, output);
+            }
+        }
+        start = end;
+    }
 
     return output;
 }
@@ -4212,6 +4270,10 @@ llama_token llama_token_nl() {
     return 13;
 }
 
+bool llama_is_special_token(const struct llama_context *ctx, llama_token token) {
+    return std::find(ctx->vocab.special_tokens.begin(), ctx->vocab.special_tokens.end(), token) != ctx->vocab.special_tokens.end();
+}
+
 struct llama_timings llama_get_timings(struct llama_context * ctx) {
     struct llama_timings result = {
         /*.t_start_ms              =*/ 1e-3 * ctx->t_start_us,
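
With the above in place, special tokens survive the trip through the public tokenizer. A hedged sketch (the "<|user|>" token and its id are invented; assumes `ctx` was loaded from a v4 file whose vocab lists that token as special):

    #include <cstdio>
    #include <vector>
    #include "llama.h"

    void demo(struct llama_context * ctx) {
        const char * prompt = "Hi<|user|>there";
        std::vector<llama_token> tokens(64);
        const int n = llama_tokenize(ctx, prompt, tokens.data(), (int) tokens.size(), /*add_bos=*/ true);
        // Before this patch, "<|user|>" would be split into several ordinary
        // pieces; with the trie-based pass it is emitted as a single id.
        for (int i = 0; i < n; i++) {
            printf("%d ", tokens[i]);
        }
        printf("\n");
    }
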
diff --git a/llama.h b/llama.h
index fa1977f2d..9ece944d9 100644
--- a/llama.h
+++ b/llama.h
@@ -40,7 +40,7 @@
 #define LLAMA_FILE_MAGIC_GGML        0x67676d6cu // 'ggml'
 #define LLAMA_FILE_MAGIC_GGSN        0x6767736eu // 'ggsn'
 
-#define LLAMA_FILE_VERSION           3
+#define LLAMA_FILE_VERSION           4
 #define LLAMA_FILE_MAGIC             LLAMA_FILE_MAGIC_GGJT
 #define LLAMA_FILE_MAGIC_UNVERSIONED LLAMA_FILE_MAGIC_GGML
 #define LLAMA_SESSION_MAGIC          LLAMA_FILE_MAGIC_GGSN
@@ -373,6 +373,8 @@ extern "C" {
     LLAMA_API llama_token llama_token_eos();  // end-of-sentence
     LLAMA_API llama_token llama_token_nl();   // next-line
 
+    LLAMA_API bool llama_is_special_token(const struct llama_context * ctx, llama_token token);
+
     // Grammar
     //
     LLAMA_API struct llama_grammar * llama_grammar_init(
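
Finally, a sketch of the new public predicate in a generation loop, for example to keep marker tokens out of user-visible output (llama_token_to_str is the existing detokenization call in this API; `ctx` and `token` come from a normal sampling loop, not shown):

    #include <cstdio>
    #include "llama.h"

    void print_token(struct llama_context * ctx, llama_token token) {
        if (llama_is_special_token(ctx, token)) {
            // control/marker token: drop it, or dispatch on it for chat roles
            return;
        }
        printf("%s", llama_token_to_str(ctx, token));
    }
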