From 370a95f5241df536def0101acf0a7c8ab9a2ca47 Mon Sep 17 00:00:00 2001 From: goerch Date: Sat, 19 Aug 2023 14:39:33 +0200 Subject: [PATCH] Improve token type support - Added @klosax code to convert.py - Improved token type support in vocabulary --- convert.py | 32 +++++++++--- llama.cpp | 91 ++++++++++++++--------------------- models/ggml-vocab-llama.gguf | Bin 467382 -> 595423 bytes 3 files changed, 59 insertions(+), 64 deletions(-) diff --git a/convert.py b/convert.py index f6237579d..da1e42df0 100755 --- a/convert.py +++ b/convert.py @@ -261,12 +261,12 @@ class BpeVocab: for i, item in enumerate(tokenizer): text: bytes = item.encode("utf-8") score: float = -i - yield text, score + yield text, score, 4 def added_tokens(self) -> Iterable[Tuple[bytes, float]]: for text in self.added_tokens_list: score = -1000.0 - yield text.encode("utf-8"), score + yield text.encode("utf-8"), score, 4 def all_tokens(self) -> Iterable[Tuple[bytes, float]]: yield from self.bpe_tokens() @@ -303,12 +303,28 @@ class SentencePieceVocab: piece = tokenizer.id_to_piece(i) text: bytes = piece.encode("utf-8") score: float = tokenizer.get_score(i) - yield text, score + + toktype = 1 # defualt to normal token type + if tokenizer.is_unknown(i): + toktype = 2 + if tokenizer.is_control(i): + toktype = 3 + + # NOTE: I think added_tokens are user defined. + # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto + # if tokenizer.is_user_defined(i): toktype = 4 + + if tokenizer.is_unused(i): + toktype = 5 + if tokenizer.is_byte(i): + toktype = 6 + + yield text, score, toktype def added_tokens(self) -> Iterable[Tuple[bytes, float]]: for text in self.added_tokens_list: score = -1000.0 - yield text.encode("utf-8"), score + yield text.encode("utf-8"), score, 4 def all_tokens(self) -> Iterable[Tuple[bytes, float]]: yield from self.sentencepiece_tokens() @@ -720,16 +736,16 @@ class OutputFile: def add_meta_vocab(self, vocab: Vocab) -> None: tokens = [] scores = [] - for text, score in vocab.all_tokens(): + toktypes = [] + for text, score, toktype in vocab.all_tokens(): tokens.append(text) scores.append(score) + toktypes.append(toktype) self.gguf.add_tokenizer_model("llama") self.gguf.add_token_list(tokens) self.gguf.add_token_scores(scores) - #self.gguf.add_token_types(toktypes) # TODO: add this - - # TODO: added / special tokens + self.gguf.add_token_types(toktypes) def add_tensor_info(self, name: str, tensor: LazyTensor) -> None: n_elements = 1 diff --git a/llama.cpp b/llama.cpp index 86c943fc6..e4347f04c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -772,15 +772,16 @@ struct llama_vocab { using id = int32_t; using token = std::string; - struct token_score { + struct token_data { token tok; float score; + int toktype; }; llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; std::unordered_map token_to_id; - std::vector id_to_token; + std::vector id_to_token; // default LLaMA special tokens id special_bos_id = 1; @@ -1499,17 +1500,25 @@ static void llama_model_load_internal( const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + const int toktype_idx = gguf_find_key(ctx, "tokenizer.ggml.token_type"); + if (toktype_idx == -1) { + throw std::runtime_error("cannot find token type list in GGUF file\n"); + } + + const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); + for (uint32_t i = 0; i < hparams.n_vocab; i++) { std::string word = gguf_get_arr_str(ctx, token_idx, i); vocab.token_to_id[word] = i; - auto & tok_score = vocab.id_to_token[i]; - tok_score.tok = std::move(word); - tok_score.score = scores[i]; + auto & token_data = vocab.id_to_token[i]; + token_data.tok = std::move(word); + token_data.score = scores[i]; + token_data.toktype = toktypes[i]; // determine the newline token: 0x0A == 10 == '\n' - if (tok_score.tok == "<0x0A>") { + if (token_data.tok == "<0x0A>") { vocab.linefeed_id = i; } } @@ -2337,68 +2346,38 @@ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) { return vocab.type; } -static bool llama_is_normal_token(const llama_vocab & vocab, llama_token token) { - if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) { - return token >= 259; - } - - if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) { - return token >= 95; - } - - return false; +static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].toktype == 1; } -static bool llama_is_bos_token(const llama_vocab & vocab, llama_token token) { - return token == vocab.special_bos_id; +static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].toktype == 2; } -static bool llama_is_eos_token(const llama_vocab & vocab, llama_token token) { - return token == vocab.special_eos_id; +static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].toktype == 3; } -static bool llama_is_control_token(const llama_vocab & vocab, llama_token token) { - if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) { - return token == llama_is_bos_token(vocab, token) || token == llama_is_eos_token(vocab, token); - } - - // TODO: improve? - return false; +static bool llama_is_bos_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(llama_is_control_token(vocab, id)); + return id == vocab.special_bos_id; } -static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token token) { - if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) { - return token == 0; - } - - // TODO: improve? - return false; +static bool llama_is_eos_token(const llama_vocab & vocab, llama_token id ) { + GGML_ASSERT(llama_is_control_token(vocab, id)); + return id == vocab.special_eos_id; } -static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) { - GGML_UNUSED(vocab); - GGML_UNUSED(token); - // TODO: improve? - return false; +static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].toktype == 4; } -static bool llama_is_unused_token(const llama_vocab & vocab, llama_token token) { - GGML_UNUSED(vocab); - GGML_UNUSED(token); - // TODO: improve? - return false; +static bool llama_is_unused_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].toktype == 5; } -static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) { - if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) { - return 3 <= token && token < 259; - } - - if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) { - return 1 <= token && token < 95; - } - - return false; +static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].toktype == 6; } static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) { @@ -2587,12 +2566,12 @@ private: return; } - const auto &tok_score = vocab_.id_to_token[(*token).second]; + const auto &tok_data = vocab_.id_to_token[(*token).second]; llama_sp_bigram bigram; bigram.left = left; bigram.right = right; - bigram.score = tok_score.score; + bigram.score = tok_data.score; bigram.size = text.size(); work_queue_.push(bigram); diff --git a/models/ggml-vocab-llama.gguf b/models/ggml-vocab-llama.gguf index c50db67dc52023444e607bde7b684da5d631e564..63bfaf672f382c0f5bbcffe54736e2698ef3ac55 100644 GIT binary patch delta 129117 zcmeI(y=qf&0Egk6endOu0tyZe&fUZ-&|ZM(YL-yMPo#^2oI%jp>u_?>$;k`xR{Td@ zU6R1NJi?)fG^c&jlRW2F{Py$KuP=92mzR$oOpCvNR-ae*r<-3No=qn|?tQ%bVLg3X zPut^Y9LM808OP~3mfzRYjmNj=^YMAtubypQK7F@&b8&U`;`zm&zdpHs`+9RW{_%AC z!0#7}qw(SR-?1DY4|B|idY~O2XPo~)-^p%|Gan}7ZYLk-p40DsInLjH0iOSTegLq~BvViN~^N#7T`brD+};nsUQ0v|CRc&|M6d`ANwEwmHM&&@gM({1^AEu$^!hye`Nvw z1|CRc&|M6d`ANwEwmHM&&@n5MQ`yc=DUs-_v_^&L$fBaV#;6MH=3-Dj5ANwEw zmHM&&@n5MQ`yc<6`mz7cd}0{q8+WdZ)KmIEV@L#DP`yc<6`mz7N#7T`brD+};nsUQ0v|CRc&|M6d`ANwEw zmHM&&@gM({1^AEu$^!hye`Nvw1|CRc&|M6d`ANwEwmHM&&@n5MQ`yc=DUs-_v z_^&L$fBaV#;6MH=3-Dj5ANwEwmHM&&@n5MQ`yc<6`mz7cd}0{q8+WdZ)< zzp?=TmHM&&@n5MQ`yc<6`mz7KmIEV@L#DP`yc<6 z`mz7N#7T`br zD+};nsUQ0v|CRc&|M6d`ANwEwmHM&&@gM({1^AEu$^!hye`Nvw1|CRc&|M6d` zANwEwmHM&&@n5MQ`yc=DUs-_v_^&L$fBaV#;6MH=3-Dj5ANwEwmHM&&@n5MQ`yc<6 z`mz7cd}0{q8+WdZ)KmIEV@L#DP`yc<6`mz7>&hswF_@E)WzA7_^X3G`_5?-{W&&bnAZ7t#Rv>2E Kp1{Z+rT_p&>=0T2