llama : prep new tokenizer support

Author: Georgi Gerganov
Date:   2023-08-23 19:08:44 +03:00
Parent: 6938c5f474
Commit: c3f8a6e49f

llama.cpp: 124 lines changed

@@ -106,6 +106,12 @@ static void llama_log_callback_default(llama_log_level level, const char * text,
 // helpers
 //
 
+static size_t utf8_len(char src) {
+    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
+    return lookup[highbits];
+}
+
 static void zeros(std::ofstream & file, size_t n) {
     char zero = 0;
     for (size_t i = 0; i < n; ++i) {
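
The lookup table encodes UTF-8 sequence lengths by the lead byte's high nibble: nibbles 0x0-0x7 are single-byte ASCII, 0xC-0xD open 2-byte sequences, 0xE opens 3-byte ones, and 0xF opens 4-byte ones (continuation bytes, nibbles 0x8-0xB, map to 1 so the scanner still advances). A standalone harness around the same function, with a main() added here purely for illustration:

#include <cstddef>
#include <cstdint>
#include <cstdio>

static size_t utf8_len(char src) {
    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
    return lookup[highbits];
}

int main() {
    printf("%zu\n", utf8_len('a'));                   // 1: ASCII, high nibble 0x6
    printf("%zu\n", utf8_len("\xC3\xA9"[0]));         // 2: U+00E9, lead byte 0xC3
    printf("%zu\n", utf8_len("\xE2\x82\xAC"[0]));     // 3: U+20AC, lead byte 0xE2
    printf("%zu\n", utf8_len("\xF0\x9F\xA6\x99"[0])); // 4: emoji, lead byte 0xF0
}
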
@@ -2948,47 +2954,41 @@ static std::string llama_unescape_whitespace(const std::string& word) {
     return word;
 }
 
-static size_t utf8_len(char src) {
-    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
-    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
-    return lookup[highbits];
-}
-
-struct llama_sp_symbol {
-    using index = int;
-    index prev;
-    index next;
-    const char * text;
-    size_t n;
-};
-
-static_assert(std::is_trivially_copyable<llama_sp_symbol>::value, "llama_sp_symbol is not trivially copyable");
-
-struct llama_sp_bigram {
-    struct comparator {
-        bool operator()(llama_sp_bigram & l, llama_sp_bigram & r) {
-            return (l.score < r.score) || (l.score == r.score && l.left > r.left);
-        }
-    };
-    using queue_storage = std::vector<llama_sp_bigram>;
-    using queue = std::priority_queue<llama_sp_bigram, queue_storage, comparator>;
-    llama_sp_symbol::index left;
-    llama_sp_symbol::index right;
-    float score;
-    size_t size;
-};
-
 // original implementation:
 // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
 struct llama_tokenizer {
-    llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}
+    struct sp_symbol {
+        using index = int;
+        index prev;
+        index next;
+        const char * text;
+        size_t n;
+    };
+
+    static_assert(std::is_trivially_copyable<sp_symbol>::value, "sp_symbol is not trivially copyable");
+
+    struct sp_bigram {
+        struct comparator {
+            bool operator()(sp_bigram & l, sp_bigram & r) {
+                return (l.score < r.score) || (l.score == r.score && l.left > r.left);
+            }
+        };
+        using queue_storage = std::vector<sp_bigram>;
+        using queue = std::priority_queue<sp_bigram, queue_storage, comparator>;
+        sp_symbol::index left;
+        sp_symbol::index right;
+        float score;
+        size_t size;
+    };
+
+    llama_tokenizer(const llama_vocab & vocab): vocab(vocab) {}
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         // split string into utf8 chars
         int index = 0;
         size_t offs = 0;
         while (offs < text.size()) {
-            llama_sp_symbol sym;
+            sp_symbol sym;
             size_t len = utf8_len(text[offs]);
             GGML_ASSERT(offs + len <= text.size());
             sym.text = text.c_str() + offs;
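
The comparator defines a "less-than" for std::priority_queue, which pops its greatest element: the top of the work queue is therefore the bigram with the highest score, with ties broken in favor of the smaller left index. A reduced sketch (toy bigram type, not the diff's) demonstrating that ordering:

#include <cstdio>
#include <queue>
#include <vector>

struct bigram {
    int   left;
    float score;
};

struct comparator {
    bool operator()(const bigram & l, const bigram & r) const {
        return (l.score < r.score) || (l.score == r.score && l.left > r.left);
    }
};

int main() {
    std::priority_queue<bigram, std::vector<bigram>, comparator> q;
    q.push({ 3, 1.0f });
    q.push({ 0, 2.0f });
    q.push({ 1, 2.0f });
    printf("%d\n", q.top().left); // 0: highest score wins, leftmost on ties
}
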
@@ -2997,21 +2997,21 @@ struct llama_tokenizer {
             sym.prev = index - 1;
             sym.next = offs == text.size() ? -1 : index + 1;
             index++;
-            symbols_.emplace_back(sym);
+            symbols.emplace_back(sym);
         }
 
         // seed the work queue with all possible 2-character tokens.
-        for (size_t i = 1; i < symbols_.size(); ++i) {
+        for (size_t i = 1; i < symbols.size(); ++i) {
             try_add_bigram(i - 1, i);
         }
 
         // keep substituting the highest frequency pairs for as long as we can.
-        while (!work_queue_.empty()) {
-            auto bigram = work_queue_.top();
-            work_queue_.pop();
+        while (!work_queue.empty()) {
+            auto bigram = work_queue.top();
+            work_queue.pop();
 
-            auto & left_sym = symbols_[bigram.left];
-            auto & right_sym = symbols_[bigram.right];
+            auto & left_sym = symbols[bigram.left];
+            auto & right_sym = symbols[bigram.right];
 
             // if one of the symbols already got merged, skip it.
             if (left_sym.n == 0 || right_sym.n == 0 ||
@@ -3028,7 +3028,7 @@ struct llama_tokenizer {
             // remove the right sym from the chain
             left_sym.next = right_sym.next;
             if (right_sym.next >= 0) {
-                symbols_[right_sym.next].prev = bigram.left;
+                symbols[right_sym.next].prev = bigram.left;
             }
 
             // find more substitutions
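
Merging a bigram leaves the right symbol in the vector but marks it consumed and splices it out of the prev/next chain, so stale queue entries that still reference it can be detected and skipped (the n == 0 test above). A simplified sketch of the splice, with hypothetical struct and helper names, assuming the merge also adds the right symbol's byte count to the left one as in the elided part of the loop:

#include <cstddef>
#include <cstdio>
#include <vector>

struct sym { int prev, next; size_t n; };

static void merge(std::vector<sym> & symbols, int left, int right) {
    symbols[left].n    += symbols[right].n;    // left absorbs right's bytes
    symbols[right].n    = 0;                   // zero length == already merged
    symbols[left].next  = symbols[right].next; // unlink right from the chain
    if (symbols[right].next >= 0) {
        symbols[symbols[right].next].prev = left;
    }
}

int main() {
    // three one-byte symbols linked 0 <-> 1 <-> 2
    std::vector<sym> symbols = { { -1, 1, 1 }, { 0, 2, 1 }, { 1, -1, 1 } };
    merge(symbols, 0, 1);
    printf("%d %zu\n", symbols[0].next, symbols[0].n); // 2 2: chain is now 0 <-> 2
}
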
@@ -3036,19 +3036,19 @@ struct llama_tokenizer {
             try_add_bigram(bigram.left, left_sym.next);
         }
 
-        for (int i = 0; i != -1; i = symbols_[i].next) {
-            auto & symbol = symbols_[i];
+        for (int i = 0; i != -1; i = symbols[i].next) {
+            auto & symbol = symbols[i];
             resegment(symbol, output);
         }
     }
 
private:
-    void resegment(llama_sp_symbol &symbol, std::vector<llama_vocab::id> &output) {
+    void resegment(sp_symbol & symbol, std::vector<llama_vocab::id> & output) {
         auto text = std::string(symbol.text, symbol.n);
-        auto token = vocab_.token_to_id.find(text);
+        auto token = vocab.token_to_id.find(text);
 
         // Do we need to support is_unused?
-        if (token != vocab_.token_to_id.end()) {
+        if (token != vocab.token_to_id.end()) {
             output.push_back((*token).second);
             return;
         }
@@ -3058,14 +3058,14 @@ private:
         if (p == rev_merge.end()) {
             // output any symbols that did not form tokens as bytes.
             for (int j = 0; j < (int)symbol.n; ++j) {
-                llama_vocab::id token_id = llama_byte_to_token(vocab_, symbol.text[j]);
+                llama_vocab::id token_id = llama_byte_to_token(vocab, symbol.text[j]);
                 output.push_back(token_id);
             }
             return;
         }
 
-        resegment(symbols_[p->second.first], output);
-        resegment(symbols_[p->second.second], output);
+        resegment(symbols[p->second.first], output);
+        resegment(symbols[p->second.second], output);
     }
 
     void try_add_bigram(int left, int right) {
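
rev_merge records, for every span that won a merge, the two sub-spans it was built from; resegment() emits a span directly when it is a vocabulary token, and otherwise splits it back along the recorded merge and recurses, bottoming out in byte tokens. A toy sketch of that recursion, with a hypothetical vocabulary and a rev_merge keyed by strings instead of the symbol indices used above:

#include <cstdio>
#include <map>
#include <string>
#include <vector>

static std::map<std::string, int> token_to_id = { { "ab", 7 }, { "c", 3 } };
static std::map<std::string, std::pair<std::string, std::string>> rev_merge = {
    { "abc", { "ab", "c" } }, // "abc" was built by merging "ab" and "c"
};

static void resegment(const std::string & text, std::vector<int> & out) {
    auto token = token_to_id.find(text);
    if (token != token_to_id.end()) {
        out.push_back(token->second); // the span is itself a token
        return;
    }
    auto p = rev_merge.find(text);
    if (p == rev_merge.end()) {
        return; // the real code falls back to byte tokens here
    }
    resegment(p->second.first,  out);
    resegment(p->second.second, out);
}

int main() {
    std::vector<int> out;
    resegment("abc", out); // "abc" is not a token -> emits ids for "ab" and "c"
    for (int id : out) printf("%d ", id); // 7 3
}
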
@@ -3073,34 +3073,36 @@ private:
             return;
         }
 
-        const std::string text = std::string(symbols_[left].text, symbols_[left].n + symbols_[right].n);
-        auto token = vocab_.token_to_id.find(text);
+        const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
+        auto token = vocab.token_to_id.find(text);
 
-        if (token == vocab_.token_to_id.end()) {
+        if (token == vocab.token_to_id.end()) {
             return;
         }
 
-        if (static_cast<size_t>((*token).second) >= vocab_.id_to_token.size()) {
+        if (static_cast<size_t>((*token).second) >= vocab.id_to_token.size()) {
             return;
         }
 
-        const auto &tok_data = vocab_.id_to_token[(*token).second];
+        const auto & tok_data = vocab.id_to_token[(*token).second];
 
-        llama_sp_bigram bigram;
-        bigram.left = left;
+        sp_bigram bigram;
+        bigram.left  = left;
         bigram.right = right;
         bigram.score = tok_data.score;
-        bigram.size = text.size();
-        work_queue_.push(bigram);
+        bigram.size  = text.size();
+        work_queue.push(bigram);
 
         // Do we need to support is_unused?
         rev_merge[text] = std::make_pair(left, right);
     }
 
-    const llama_vocab & vocab_;
-    std::vector<llama_sp_symbol> symbols_;
-    llama_sp_bigram::queue work_queue_;
-    std::map<std::string, std::pair<int, int> > rev_merge;
+    const llama_vocab & vocab;
+
+    std::vector<sp_symbol> symbols;
+    sp_bigram::queue work_queue;
+
+    std::map<std::string, std::pair<int, int>> rev_merge;
};
 
 static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape) {
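
The tokenizer's public surface is unchanged by this refactor: construct over a vocab, then call tokenize(). A minimal usage sketch, assuming vocab and raw_text are in scope as they are inside llama_tokenize_internal:

std::vector<llama_vocab::id> output;

llama_tokenizer tokenizer(vocab);
tokenizer.tokenize(raw_text, output); // output now holds the token ids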