add in wordpiece tokenizer
This commit is contained in:
parent
0051c82d52
commit
59c1829b0c
4 changed files with 279 additions and 19 deletions
|
@ -1594,26 +1594,37 @@ class BertModel(Model):
|
|||
path = self.dir_model
|
||||
added_tokens_path = self.dir_model if self.dir_model.exists() else None
|
||||
|
||||
# use huggingface vocab to get all tokens
|
||||
vocab = HfVocab(path, added_tokens_path)
|
||||
tokens, scores, toktypes = zip(*vocab.all_tokens())
|
||||
|
||||
assert len(tokens) == vocab.vocab_size
|
||||
|
||||
# for some reason set(toktypes) = {1, 3} so we need to compress it
|
||||
all_types, toktypes1 = np.unique(toktypes, return_inverse=True)
|
||||
n_token_types, toktypes1 = len(all_types), toktypes1.tolist()
|
||||
# we need this to validate the size of the token_type embeddings
|
||||
# though currently we are passing all zeros to the token_type embeddings
|
||||
n_token_types = len(set(toktypes))
|
||||
self.gguf_writer.add_uint32("tokenizer.ggml.token_type_count", n_token_types)
|
||||
|
||||
# convert tokens to SPM style
|
||||
tokens = [
|
||||
(t[2:] if t.startswith(b"##") else b"\xe2\x96\x81" + t) for t in tokens
|
||||
]
|
||||
# convert to phantom space vocab
|
||||
def phantom(tok, typ):
|
||||
if tok.startswith(b'[') and tok.endswith(b']'):
|
||||
return tok
|
||||
elif tok.startswith(b"##"):
|
||||
return tok[2:]
|
||||
else:
|
||||
return b"\xe2\x96\x81" + tok
|
||||
tokens = [phantom(t, y) for t, y in zip(tokens, toktypes)]
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("llama")
|
||||
# set up bos and eos tokens (cls and sep)
|
||||
self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
|
||||
self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)
|
||||
|
||||
# add vocab to gguf
|
||||
self.gguf_writer.add_tokenizer_model("bert")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
self.gguf_writer.add_token_types(toktypes) # ignore types for now (all zero)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
# handle special tokens
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
|
|
|
@ -87,7 +87,17 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
const int n_embd = llama_n_embd(model);
|
||||
const auto * embeddings = llama_get_embeddings(ctx);
|
||||
auto * embeddings = llama_get_embeddings(ctx);
|
||||
|
||||
// l2-normalize embeddings
|
||||
float norm = 0;
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
norm += embeddings[i] * embeddings[i];
|
||||
}
|
||||
norm = sqrt(norm);
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
embeddings[i] /= norm;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
printf("%f ", embeddings[i]);
|
||||
|
|
252
llama.cpp
252
llama.cpp
|
@ -2860,6 +2860,7 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
|
|||
switch (type) {
|
||||
case LLAMA_VOCAB_TYPE_SPM: return "SPM";
|
||||
case LLAMA_VOCAB_TYPE_BPE: return "BPE";
|
||||
case LLAMA_VOCAB_TYPE_WPM: return "WPM";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
@ -3033,6 +3034,7 @@ static void llm_load_hparams(
|
|||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
|
||||
hparams.causal_attn = false;
|
||||
|
||||
switch (hparams.n_embd) {
|
||||
case 384: // MiniLM
|
||||
|
@ -3248,6 +3250,16 @@ static void llm_load_vocab(
|
|||
vocab.special_unk_id = -1;
|
||||
vocab.special_sep_id = -1;
|
||||
vocab.special_pad_id = -1;
|
||||
} else if (tokenizer_name == "bert") {
|
||||
vocab.type = LLAMA_VOCAB_TYPE_WPM;
|
||||
|
||||
// default special tokens
|
||||
vocab.special_bos_id = 101;
|
||||
vocab.special_eos_id = 102;
|
||||
vocab.special_unk_id = 100;
|
||||
vocab.special_sep_id = -1;
|
||||
vocab.special_pad_id = -1;
|
||||
vocab.add_space_prefix = false;
|
||||
} else {
|
||||
LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
|
||||
LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
|
||||
|
@ -3275,11 +3287,9 @@ static void llm_load_vocab(
|
|||
|
||||
// determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
|
||||
if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
|
||||
try {
|
||||
vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
|
||||
} catch (const std::exception & e) {
|
||||
LLAMA_LOG_WARN("%s: model vocab missing newline token: %s\n", __func__, e.what());
|
||||
}
|
||||
} else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
|
||||
vocab.linefeed_id = vocab.special_pad_id;
|
||||
} else {
|
||||
const std::vector<int> ids = llama_tokenize_internal(vocab, "\u010A", false);
|
||||
GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
|
||||
|
@ -5725,11 +5735,14 @@ struct llm_build_context {
|
|||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
// get input vectors with right size
|
||||
struct ggml_tensor * inp_type = ggml_view_1d(ctx0, lctx.inp_type, n_tokens, 0);
|
||||
struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
|
||||
struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0);
|
||||
|
||||
// construct input embeddings (token, type, position)
|
||||
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
|
||||
struct ggml_tensor * inp_type = ggml_view_1d(ctx0, lctx.inp_type, n_tokens, 0);
|
||||
inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.type_embd, inp_type), inpL);
|
||||
struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
|
||||
inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
|
||||
cb(inpL, "inp_embd", -1);
|
||||
|
||||
|
@ -5794,7 +5807,6 @@ struct llm_build_context {
|
|||
cur = inpL;
|
||||
|
||||
// pooling
|
||||
struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0);
|
||||
cur = ggml_mul_mat(ctx0, inp_sum, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
|
||||
cb(cur, "result_embed", -1);
|
||||
|
||||
|
@ -7655,6 +7667,9 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
|
|||
GGML_ASSERT(false);
|
||||
return unicode_to_bytes_bpe(token_data.text);
|
||||
}
|
||||
case LLAMA_VOCAB_TYPE_WPM: {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
@ -7667,6 +7682,7 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
|
|||
const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
|
||||
return vocab.token_to_id.at(buf);
|
||||
}
|
||||
case LLAMA_VOCAB_TYPE_WPM:
|
||||
case LLAMA_VOCAB_TYPE_BPE: {
|
||||
return vocab.token_to_id.at(bytes_to_unicode_bpe(ch));
|
||||
}
|
||||
|
@ -8137,6 +8153,207 @@ private:
|
|||
llm_bigram_bpe::queue work_queue;
|
||||
};
|
||||
|
||||
struct llm_tokenizer_wpm {
|
||||
llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
|
||||
|
||||
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||
auto * token_map = &vocab.token_to_id;
|
||||
|
||||
// normalize and split by whitespace
|
||||
std::vector<std::string> words = preprocess(text);
|
||||
|
||||
// bos token prepended already
|
||||
|
||||
// find the longest tokens that form the words
|
||||
for (const std::string &word : words) {
|
||||
// skip empty words
|
||||
if (word.size() == 0) continue;
|
||||
|
||||
// prepend phantom space
|
||||
std::string word1 = "\xe2\x96\x81" + word;
|
||||
int n = word1.size();
|
||||
|
||||
// we're at the start of a new word
|
||||
int i = 0;
|
||||
bool match_any = false;
|
||||
|
||||
// move through character position in word
|
||||
while (i < n) {
|
||||
// loop through possible match length
|
||||
bool match = false;
|
||||
for (int j = n; j > i; j--) {
|
||||
auto it = token_map->find(word1.substr(i, j - i));
|
||||
if (it != token_map->end()) {
|
||||
output.push_back(it->second);
|
||||
match = true;
|
||||
match_any = true;
|
||||
i = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// must be an unknown character
|
||||
if (!match) i++;
|
||||
}
|
||||
|
||||
// we didn't find any matches for this word
|
||||
if (!match_any) {
|
||||
output.push_back(vocab.special_unk_id);
|
||||
}
|
||||
}
|
||||
|
||||
// append eos token
|
||||
output.push_back(vocab.special_eos_id);
|
||||
}
|
||||
|
||||
std::vector<std::string> preprocess(const std::string & text) {
|
||||
std::string ori_str = text;
|
||||
ori_str = normalize(ori_str);
|
||||
uint64_t ori_size = ori_str.size();
|
||||
|
||||
// single punct / single symbol / single digit
|
||||
// baseline: add whitespace on the left and right of punct and chinese characters
|
||||
std::vector<std::string> words;
|
||||
std::string new_str = "";
|
||||
uint64_t i = 0;
|
||||
while (i < ori_size) {
|
||||
int utf_char_len = utf8_len(ori_str[i]);
|
||||
if ((utf_char_len == 1) && ispunct(ori_str[i])) {
|
||||
new_str += " ";
|
||||
new_str += ori_str[i];
|
||||
new_str += " ";
|
||||
i += 1;
|
||||
}
|
||||
else if ((utf_char_len == 3) && is_chinese_char(ori_str.substr(i, 3))) {
|
||||
new_str += " ";
|
||||
new_str += ori_str.substr(i, 3);
|
||||
new_str += " ";
|
||||
i += 3;
|
||||
}
|
||||
else {
|
||||
new_str += ori_str[i];
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// split by whitespace
|
||||
uint64_t l = 0;
|
||||
uint64_t r = 0;
|
||||
while (r < new_str.size()) {
|
||||
// if is whitespace
|
||||
if (isspace(new_str[r])) {
|
||||
if (r > l) words.push_back(new_str.substr(l, (r - l)));
|
||||
l = r + 1;
|
||||
r = l;
|
||||
}
|
||||
else {
|
||||
r += 1;
|
||||
}
|
||||
}
|
||||
if (r > l) {
|
||||
words.push_back(new_str.substr(l, (r - l)));
|
||||
}
|
||||
return words;
|
||||
}
|
||||
|
||||
std::string normalize(const std::string &text) {
|
||||
// TODO: handle chinese characters? https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98
|
||||
std::string text2 = strip_accents(text);
|
||||
for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i]))
|
||||
{
|
||||
char c = text2[i];
|
||||
if (c >= 'A' && c <= 'Z')
|
||||
text2[i] = c - 'A' + 'a';
|
||||
}
|
||||
return text2;
|
||||
}
|
||||
|
||||
bool is_chinese_char(const std::string& str) {
|
||||
int len = str.length();
|
||||
unsigned int codepoint = 0;
|
||||
int num_bytes = 0;
|
||||
int i = 0;
|
||||
unsigned char ch = static_cast<unsigned char>(str[i]);
|
||||
if (ch <= 0x7f) {
|
||||
codepoint = ch;
|
||||
num_bytes = 1;
|
||||
} else if ((ch >> 5) == 0x06) {
|
||||
codepoint = ch & 0x1f;
|
||||
num_bytes = 2;
|
||||
} else if ((ch >> 4) == 0x0e) {
|
||||
codepoint = ch & 0x0f;
|
||||
num_bytes = 3;
|
||||
} else if ((ch >> 3) == 0x1e) {
|
||||
codepoint = ch & 0x07;
|
||||
num_bytes = 4;
|
||||
}
|
||||
for (int j = 1; j < num_bytes; ++j) {
|
||||
if (i + j >= len) {
|
||||
return false; // incomplete UTF-8 character
|
||||
}
|
||||
unsigned char next_ch = static_cast<unsigned char>(str[i + j]);
|
||||
if ((next_ch >> 6) != 0x02) {
|
||||
return false; // invalid trailing byte
|
||||
}
|
||||
codepoint = (codepoint << 6) | (next_ch & 0x3f);
|
||||
}
|
||||
if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) ||
|
||||
(codepoint >= 0x3400 && codepoint <= 0x4DBF) ||
|
||||
(codepoint >= 0x20000 && codepoint <= 0x2A6DF) ||
|
||||
(codepoint >= 0x2A700 && codepoint <= 0x2B73F) ||
|
||||
(codepoint >= 0x2B740 && codepoint <= 0x2B81F) ||
|
||||
(codepoint >= 0x2B920 && codepoint <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
|
||||
(codepoint >= 0xF900 && codepoint <= 0xFAFF) ||
|
||||
(codepoint >= 0x2F800 && codepoint <= 0x2FA1F) ||
|
||||
(codepoint >= 0x3000 && codepoint <= 0x303F) ||
|
||||
(codepoint >= 0xFF00 && codepoint <= 0xFFEF)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string strip_accents(const std::string &inputString) {
|
||||
std::string resultString;
|
||||
std::map<std::string, char> accentMap = {
|
||||
{"À", 'A'}, {"Á", 'A'}, {"Â", 'A'}, {"Ã", 'A'}, {"Ä", 'A'}, {"Å", 'A'},
|
||||
{"à", 'a'}, {"á", 'a'}, {"â", 'a'}, {"ã", 'a'}, {"ä", 'a'}, {"å", 'a'},
|
||||
{"È", 'E'}, {"É", 'E'}, {"Ê", 'E'}, {"Ë", 'E'}, {"è", 'e'}, {"é", 'e'},
|
||||
{"ê", 'e'}, {"ë", 'e'}, {"Ì", 'I'}, {"Í", 'I'}, {"Î", 'I'}, {"Ï", 'I'},
|
||||
{"ì", 'i'}, {"í", 'i'}, {"î", 'i'}, {"ï", 'i'}, {"Ò", 'O'}, {"Ó", 'O'},
|
||||
{"Ô", 'O'}, {"Õ", 'O'}, {"Ö", 'O'}, {"ò", 'o'}, {"ó", 'o'}, {"ô", 'o'},
|
||||
{"õ", 'o'}, {"ö", 'o'}, {"Ù", 'U'}, {"Ú", 'U'}, {"Û", 'U'}, {"Ü", 'U'},
|
||||
{"ù", 'u'}, {"ú", 'u'}, {"û", 'u'}, {"ü", 'u'}, {"Ý", 'Y'}, {"ý", 'y'},
|
||||
{"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'},{ "ñ", 'n'},
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < inputString.length();)
|
||||
{
|
||||
int len = utf8_len(inputString[i]);
|
||||
std::string curChar = inputString.substr(i, len);
|
||||
auto iter = accentMap.find(curChar);
|
||||
if (iter != accentMap.end())
|
||||
{
|
||||
resultString += iter->second;
|
||||
}
|
||||
else
|
||||
{
|
||||
resultString += curChar;
|
||||
}
|
||||
i += len;
|
||||
}
|
||||
|
||||
return resultString;
|
||||
}
|
||||
|
||||
static size_t utf8_len(char src) {
|
||||
const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
|
||||
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
|
||||
return lookup[highbits];
|
||||
}
|
||||
|
||||
const llama_vocab & vocab;
|
||||
};
|
||||
|
||||
typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{
|
||||
FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
|
||||
FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
|
||||
|
@ -8341,6 +8558,26 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
|||
}
|
||||
}
|
||||
} break;
|
||||
case LLAMA_VOCAB_TYPE_WPM:
|
||||
{
|
||||
for (const auto & fragment: fragment_buffer)
|
||||
{
|
||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
|
||||
{
|
||||
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
||||
|
||||
#ifdef PRETOKENIZERDEBUG
|
||||
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
|
||||
#endif
|
||||
llm_tokenizer_wpm tokenizer(vocab);
|
||||
tokenizer.tokenize(raw_text, output);
|
||||
}
|
||||
else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
||||
{
|
||||
output.push_back(fragment.token);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
}
|
||||
|
||||
return output;
|
||||
|
@ -11947,6 +12184,7 @@ static std::string llama_decode_text(const std::string & text) {
|
|||
int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length) {
|
||||
if (0 <= token && token < llama_n_vocab(model)) {
|
||||
switch (llama_vocab_get_type(model->vocab)) {
|
||||
case LLAMA_VOCAB_TYPE_WPM:
|
||||
case LLAMA_VOCAB_TYPE_SPM: {
|
||||
// NOTE: we accept all unsupported token types,
|
||||
// suppressing them like CONTROL tokens.
|
||||
|
|
1
llama.h
1
llama.h
|
@ -61,6 +61,7 @@ extern "C" {
|
|||
enum llama_vocab_type {
|
||||
LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
|
||||
LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
|
||||
LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece
|
||||
};
|
||||
|
||||
enum llama_token_type {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue