Streamlining code and adding some more assertions

Important change: I'm classifying added tokens as control tokens now for BPE.
goerch 2023-09-19 19:04:48 +02:00
parent a4e9448ef0
commit 17ca832717
5 changed files with 42 additions and 42 deletions
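
For the BPE path, the net effect of this commit can be summarized by the following sketch (a simplification written for this note, not part of the commit; `n_base_tokens` and `added_tokens` are placeholder inputs):

    import gguf

    def bpe_token_types(n_base_tokens: int, added_tokens: list[str]):
        # Base vocabulary entries are no longer checked against byte_decoder,
        # so the BYTE classification disappears from the BPE converters.
        for _ in range(n_base_tokens):
            yield gguf.TokenType.NORMAL      # previously NORMAL or BYTE
        # Added tokens are now classified as control tokens.
        for _ in added_tokens:
            yield gguf.TokenType.CONTROL     # previously USER_DEFINED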


@@ -161,12 +161,8 @@ byte_encoder = bytes_to_unicode()
 byte_decoder = {v: k for k, v in byte_encoder.items()}

 for i in range(vocab_size):
-    text = reverse_vocab[i]
-    tokens.append(text)
+    tokens.append(reverse_vocab[i])
     scores.append(0.0)  # dummy
-    if text in byte_decoder:
-        toktypes.append(gguf.TokenType.BYTE)
-    else:
-        toktypes.append(gguf.TokenType.NORMAL)
+    toktypes.append(gguf.TokenType.NORMAL)

 gguf_writer.add_token_list(tokens)


@@ -343,19 +343,13 @@ class BpeVocab:
         byte_encoder = tokenization_gpt2.bytes_to_unicode()
         byte_decoder = {v: k for k, v in byte_encoder.items()}
-        score = 0.0
         for i, _ in enumerate(tokenizer):
-            text = reverse_vocab[i]
-            if text in byte_decoder:
-                toktype = gguf.TokenType.BYTE
-            else:
-                toktype = gguf.TokenType.NORMAL
-            yield text, score, toktype
+            yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL

     def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         for text in self.added_tokens_list:
             score = -1000.0
-            yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
+            yield text.encode("utf-8"), score, gguf.TokenType.CONTROL

     def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         yield from self.bpe_tokens()
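
A quick way to see the new classification from convert.py is to tally the token types coming out of all_tokens(); this is a hypothetical helper, assuming `vocab` is a BpeVocab (or compatible) instance built from a model's tokenizer files and that all_tokens() also yields the added tokens, as in the full convert.py:

    from collections import Counter

    def summarize_token_types(vocab):
        # With this change, added tokens are counted under gguf.TokenType.CONTROL
        # and base BPE tokens under gguf.TokenType.NORMAL; BYTE no longer appears.
        return Counter(toktype for _text, _score, toktype in vocab.all_tokens())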


@@ -3884,6 +3884,10 @@ static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
     return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE;
 }

+static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) {
+    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED;
+}
+
 static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
     GGML_ASSERT(llama_is_byte_token(vocab, id));
     const auto& token_data = vocab.id_to_token.at(id);
@@ -7224,15 +7228,11 @@ static std::string llama_decode_text(const std::string& text) {
 // does not write null-terminator to buf
 int llama_token_to_piece_with_model(const struct llama_model * model, llama_token token, char * buf, int length) {
     if (0 <= token && token < llama_model_n_vocab(model)) {
+        switch (llama_vocab_get_type(model->vocab)) {
+        case LLAMA_VOCAB_TYPE_SPM: {
             if (llama_is_normal_token(model->vocab, token)) {
                 std::string result = model->vocab.id_to_token[token].text;
-                if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) {
                 llama_unescape_whitespace(result);
-                } else if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_BPE) {
-                    result = llama_decode_text(result);
-                } else {
-                    GGML_ASSERT(false);
-                }
                 if (length < (int) result.length()) {
                     return -result.length();
                 }
@@ -7242,29 +7242,39 @@ int llama_token_to_piece_with_model(const struct llama_model * model, llama_token token, char * buf, int length) {
                 if (length < 3) {
                     return -3;
                 }
-                buf[0] = '\xe2';
-                buf[1] = '\x96';
-                buf[2] = '\x85';
+                memcpy(buf, "\xe2\x96\x85", 3);
                 return 3;
             } else if (llama_is_control_token(model->vocab, token)) {
                 ;
             } else if (llama_is_byte_token(model->vocab, token)) {
-                if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) {
                 if (length < 1) {
                     return -1;
                 }
                 buf[0] = llama_token_to_byte(model->vocab, token);
                 return 1;
-                } else if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_BPE) {
-                    std::string result = llama_decode_text(model->vocab.id_to_token[token].text);
+            } else {
+                GGML_ASSERT(false);
+            }
+            break;
+        }
+        case LLAMA_VOCAB_TYPE_BPE: {
+            if (llama_is_normal_token(model->vocab, token)) {
+                std::string result = model->vocab.id_to_token[token].text;
+                result = llama_decode_text(result);
                 if (length < (int) result.length()) {
                     return -result.length();
                 }
                 memcpy(buf, result.c_str(), result.length());
                 return result.length();
+            } else if (llama_is_control_token(model->vocab, token)) {
+                ;
             } else {
                 GGML_ASSERT(false);
             }
+            break;
+        }
+        default:
+            GGML_ASSERT(false);
         }
     }
     return 0;
