From 7e1041a730fdfe0f5fb3b66280f1b8be0c923ee5 Mon Sep 17 00:00:00 2001
From: Jakub Horak
Date: Fri, 17 Mar 2023 17:35:41 +0100
Subject: [PATCH] Implement non-greedy tokenizer that tries to maximize token
 lengths

---
 utils.cpp | 70 ++++++++++++++++++++++++++++++++++---------------------
 1 file changed, 43 insertions(+), 27 deletions(-)

diff --git a/utils.cpp b/utils.cpp
index 26e313d5f..7539edd86 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -275,40 +275,56 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
     return tokens;
 }
 
+// TODO: Calculate this constant from the vocabulary
+#define MAX_TOKEN_LEN 18
+// SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
 std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos) {
-    //auto res = gpt_tokenize(vocab, text);
-
-    //if (bos) {
-    //    res.insert(res.begin(), 1); // TODO: replace with vocab.bos
-    //}
-
     std::vector<gpt_vocab::id> res;
+    std::vector<int> score;
+    std::vector<gpt_vocab::id> prev;
+    int len = text.length();
+
+    score.resize(len + 1);
+    prev.resize(len + 1);
+
+    // Forward pass
+    for (int i = 0; i < len; i++) {
+        int max_len = std::min(len - i, MAX_TOKEN_LEN);
+        for (int sub_len = 1; sub_len <= max_len; sub_len++) {
+            auto sub = text.substr(i, sub_len);
+            auto token = vocab.token_to_id.find(sub);
+            if (token != vocab.token_to_id.end()) {
+                int token_score = sub.length() * sub.length();
+                int local_score = score[i] + token_score;
+                int next = i + sub_len;
+                if (score[next] < local_score) {
+                    score[next] = local_score;
+                    prev[next] = (*token).second;
+                }
+            }
+        }
+    }
+
+    // Backward pass
+    int i = len;
+    while (i > 0) {
+        gpt_vocab::id token_id = prev[i];
+        if (token_id == 0) {
+            // TODO: Return error or something more meaningful
+            printf("failed to tokenize string!\n");
+            break;
+        }
+        res.push_back(token_id);
+        auto token = (*vocab.id_to_token.find(token_id)).second;
+        i -= token.length();
+    }
 
     if (bos) {
         res.push_back(1); // TODO: replace with vocab.bos
     }
 
-    //find the longest token that matches the text
-    int pos = 0;
-    while (true) {
-        int l = 0;
-        int t = 0;
-        for (const auto & kv : vocab.id_to_token) {
-            if (kv.second.size() < l) continue;
-            if (kv.second.size() > text.size() - pos) continue;
-            if (text.substr(pos, kv.second.size()) == kv.second) {
-                l = kv.second.size();
-                t = kv.first;
-            }
-        }
-
-        if (l == 0) {
-            break;
-        }
-
-        res.push_back(t);
-        pos += l;
-    }
+    // Pieces are in reverse order so correct that
+    std::reverse(res.begin(), res.end());
 
     return res;
 }
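
The forward pass in this patch is a Viterbi-style dynamic program: score[i] holds the best score achievable for the first i bytes of the text, every in-vocabulary substring starting at i relaxes the position where it ends, and prev[] records the winning token so the backward pass can walk the chosen segmentation back from the end of the string. What follows is a minimal self-contained sketch of that same idea against a toy vocabulary, so it can be compiled and run outside the repo; the Vocab struct, tokenize(), kMaxTokenLen, and the example pieces are illustrative stand-ins, not the repo's real gpt_vocab API.

// Standalone sketch of the dynamic-programming tokenizer from the patch.
// Vocab is a simplified stand-in for gpt_vocab in utils.h, not the real type.
#include <algorithm>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

static const int kMaxTokenLen = 18; // mirrors MAX_TOKEN_LEN in the patch

struct Vocab {
    std::map<std::string, int> token_to_id;
    std::map<int, std::string> id_to_token;
};

// Finds the segmentation of `text` that maximizes the sum of squared token
// lengths, then walks back through prev[] to recover the token sequence.
std::vector<int> tokenize(const Vocab & vocab, const std::string & text) {
    const int len = (int) text.length();
    std::vector<int> score(len + 1, 0); // best score for text[0..i)
    std::vector<int> prev(len + 1, 0);  // id of last token on best path to i

    // Forward pass: relax every in-vocabulary substring starting at i
    for (int i = 0; i < len; i++) {
        const int max_len = std::min(len - i, kMaxTokenLen);
        for (int sub_len = 1; sub_len <= max_len; sub_len++) {
            auto it = vocab.token_to_id.find(text.substr(i, sub_len));
            if (it == vocab.token_to_id.end()) continue;
            const int candidate = score[i] + sub_len * sub_len;
            if (score[i + sub_len] < candidate) {
                score[i + sub_len] = candidate;
                prev[i + sub_len] = it->second;
            }
        }
    }

    // Backward pass: follow prev[] from the end, then reverse the pieces
    std::vector<int> res;
    int i = len;
    while (i > 0) {
        const int id = prev[i];
        if (id == 0) {
            printf("failed to tokenize string!\n");
            break;
        }
        res.push_back(id);
        i -= (int) vocab.id_to_token.at(id).length();
    }
    std::reverse(res.begin(), res.end());
    return res;
}

int main() {
    Vocab vocab;
    // Toy vocabulary: ids start at 2 so 0 can mean "no token ends here"
    const char * pieces[] = { "un", "believ", "able",
                              "u", "n", "b", "e", "l", "i", "v", "a" };
    int id = 2;
    for (const char * p : pieces) {
        vocab.token_to_id[p] = id;
        vocab.id_to_token[id] = p;
        id++;
    }
    for (int t : tokenize(vocab, "unbelievable")) {
        printf("%d:%s ", t, vocab.id_to_token[t].c_str());
    }
    printf("\n");
    return 0;
}

With this toy vocabulary the squared-length score picks un/believ/able (4 + 36 + 16 = 56) over the twelve single characters (12), which is the "maximize token lengths" behavior the subject line describes. Note the design choice the patch makes: prev[i] == 0 doubles as "no token ends here", which quietly assumes id 0 never appears as a legitimate token on a best path; the toy vocabulary above reserves id 0 for that reason.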