llama : prep new tokenizer support

parent 6938c5f474
commit c3f8a6e49f

llama.cpp | 124 ++++++++++++++++++++++++++++++++--------------------------------
1 changed file with 63 additions and 61 deletions
@@ -106,6 +106,12 @@ static void llama_log_callback_default(llama_log_level level, const char * text,
 // helpers
 //
 
+static size_t utf8_len(char src) {
+    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
+    return lookup[highbits];
+}
+
 static void zeros(std::ofstream & file, size_t n) {
     char zero = 0;
     for (size_t i = 0; i < n; ++i) {
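The helper moved up in this hunk keys on the top four bits of the lead byte: values 0x0–0xB cover ASCII and continuation bytes (length 1), 0xC–0xD are 2-byte lead bytes, 0xE is 3-byte, and 0xF is 4-byte. A minimal standalone check of that mapping (a sketch for illustration, not part of the commit):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Same helper as above, copied here so the check compiles standalone.
    static size_t utf8_len(char src) {
        const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
        uint8_t highbits = static_cast<uint8_t>(src) >> 4;
        return lookup[highbits];
    }

    int main() {
        printf("%zu\n", utf8_len('a'));                   // 1: ASCII, high nibble 0x6
        printf("%zu\n", utf8_len("\xC3\xA9"[0]));         // 2: U+00E9, lead byte 0xC3
        printf("%zu\n", utf8_len("\xE2\x82\xAC"[0]));     // 3: U+20AC, lead byte 0xE2
        printf("%zu\n", utf8_len("\xF0\x9F\x98\x80"[0])); // 4: U+1F600, lead byte 0xF0
        return 0;
    }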
@@ -2948,47 +2954,41 @@ static std::string llama_unescape_whitespace(const std::string& word) {
     return word;
 }
 
-static size_t utf8_len(char src) {
-    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
-    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
-    return lookup[highbits];
-}
-
-struct llama_sp_symbol {
-    using index = int;
-    index prev;
-    index next;
-    const char * text;
-    size_t n;
-};
-
-static_assert(std::is_trivially_copyable<llama_sp_symbol>::value, "llama_sp_symbol is not trivially copyable");
-
-struct llama_sp_bigram {
-    struct comparator {
-        bool operator()(llama_sp_bigram & l, llama_sp_bigram & r) {
-            return (l.score < r.score) || (l.score == r.score && l.left > r.left);
-        }
-    };
-    using queue_storage = std::vector<llama_sp_bigram>;
-    using queue = std::priority_queue<llama_sp_bigram, queue_storage, comparator>;
-    llama_sp_symbol::index left;
-    llama_sp_symbol::index right;
-    float score;
-    size_t size;
-};
-
 // original implementation:
 // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
 struct llama_tokenizer {
-    llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}
+    struct sp_symbol {
+        using index = int;
+        index prev;
+        index next;
+        const char * text;
+        size_t n;
+    };
+
+    static_assert(std::is_trivially_copyable<sp_symbol>::value, "sp_symbol is not trivially copyable");
+
+    struct sp_bigram {
+        struct comparator {
+            bool operator()(sp_bigram & l, sp_bigram & r) {
+                return (l.score < r.score) || (l.score == r.score && l.left > r.left);
+            }
+        };
+        using queue_storage = std::vector<sp_bigram>;
+        using queue = std::priority_queue<sp_bigram, queue_storage, comparator>;
+        sp_symbol::index left;
+        sp_symbol::index right;
+        float score;
+        size_t size;
+    };
+
+    llama_tokenizer(const llama_vocab & vocab): vocab(vocab) {}
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         // split string into utf8 chars
         int index = 0;
         size_t offs = 0;
         while (offs < text.size()) {
-            llama_sp_symbol sym;
+            sp_symbol sym;
             size_t len = utf8_len(text[offs]);
             GGML_ASSERT(offs + len <= text.size());
             sym.text = text.c_str() + offs;
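The main change in this hunk is structural: llama_sp_symbol and llama_sp_bigram move inside llama_tokenizer as sp_symbol and sp_bigram, and the trailing-underscore members are renamed, without touching behavior. The comparator makes the work queue a max-heap on token score, breaking ties toward the smaller left index. A trimmed standalone illustration of that ordering (field subset chosen here for brevity; not the commit's code):

    #include <cstdio>
    #include <queue>
    #include <vector>

    // Mirror of the nested sp_bigram above, reduced to the fields the
    // comparator actually reads.
    struct sp_bigram {
        struct comparator {
            bool operator()(sp_bigram & l, sp_bigram & r) {
                return (l.score < r.score) || (l.score == r.score && l.left > r.left);
            }
        };
        using queue_storage = std::vector<sp_bigram>;
        using queue = std::priority_queue<sp_bigram, queue_storage, comparator>;
        int   left;
        float score;
    };

    int main() {
        sp_bigram::queue q;
        q.push({ 5, -1.5f });
        q.push({ 0, -1.5f });
        q.push({ 2, -0.5f });

        // Highest score pops first; equal scores resolve to the smaller left
        // index: prints left=2, then left=0, then left=5.
        while (!q.empty()) {
            printf("left=%d score=%.1f\n", q.top().left, q.top().score);
            q.pop();
        }
        return 0;
    }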
@@ -2997,21 +2997,21 @@ struct llama_tokenizer {
             sym.prev = index - 1;
             sym.next = offs == text.size() ? -1 : index + 1;
             index++;
-            symbols_.emplace_back(sym);
+            symbols.emplace_back(sym);
         }
 
         // seed the work queue with all possible 2-character tokens.
-        for (size_t i = 1; i < symbols_.size(); ++i) {
+        for (size_t i = 1; i < symbols.size(); ++i) {
             try_add_bigram(i - 1, i);
         }
 
         // keep substituting the highest frequency pairs for as long as we can.
-        while (!work_queue_.empty()) {
-            auto bigram = work_queue_.top();
-            work_queue_.pop();
+        while (!work_queue.empty()) {
+            auto bigram = work_queue.top();
+            work_queue.pop();
 
-            auto & left_sym = symbols_[bigram.left];
-            auto & right_sym = symbols_[bigram.right];
+            auto & left_sym = symbols[bigram.left];
+            auto & right_sym = symbols[bigram.right];
 
             // if one of the symbols already got merged, skip it.
             if (left_sym.n == 0 || right_sym.n == 0 ||
@@ -3028,7 +3028,7 @@ struct llama_tokenizer {
             // remove the right sym from the chain
             left_sym.next = right_sym.next;
             if (right_sym.next >= 0) {
-                symbols_[right_sym.next].prev = bigram.left;
+                symbols[right_sym.next].prev = bigram.left;
             }
 
             // find more substitutions
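The two hunks above are the heart of the merge loop: pop the best-scoring bigram, grow the left symbol, zero out the right one (so stale queue entries fail the n == 0 check), and splice the right symbol out of the doubly linked chain. A self-contained sketch of one such merge over "abc" (mirroring sp_symbol's fields; an illustration, not the commit's code):

    #include <cstddef>
    #include <cstdio>

    // Same fields as the nested sp_symbol; an index of -1 terminates the chain.
    struct sym { int prev; int next; const char * text; size_t n; };

    int main() {
        const char * s = "abc";
        sym symbols[3] = {
            { -1,  1, s + 0, 1 },
            {  0,  2, s + 1, 1 },
            {  1, -1, s + 2, 1 },
        };

        // Merge the bigram (0,1), i.e. "a" + "b" -> "ab": the left symbol
        // grows, the right one is emptied (n = 0), and the right one is
        // spliced out of the doubly linked chain.
        int left = 0, right = 1;
        symbols[left].n += symbols[right].n;
        symbols[right].n = 0;
        symbols[left].next = symbols[right].next;
        if (symbols[right].next >= 0) {
            symbols[symbols[right].next].prev = left;
        }

        // Walking the chain now yields "ab", then "c".
        for (int i = 0; i != -1; i = symbols[i].next) {
            printf("%.*s\n", (int) symbols[i].n, symbols[i].text);
        }
        return 0;
    }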
@@ -3036,19 +3036,19 @@ struct llama_tokenizer {
                 try_add_bigram(bigram.left, left_sym.next);
             }
 
-        for (int i = 0; i != -1; i = symbols_[i].next) {
-            auto & symbol = symbols_[i];
+        for (int i = 0; i != -1; i = symbols[i].next) {
+            auto & symbol = symbols[i];
             resegment(symbol, output);
         }
     }
 
 private:
-    void resegment(llama_sp_symbol &symbol, std::vector<llama_vocab::id> &output) {
+    void resegment(sp_symbol & symbol, std::vector<llama_vocab::id> & output) {
         auto text = std::string(symbol.text, symbol.n);
-        auto token = vocab_.token_to_id.find(text);
+        auto token = vocab.token_to_id.find(text);
 
         // Do we need to support is_unused?
-        if (token != vocab_.token_to_id.end()) {
+        if (token != vocab.token_to_id.end()) {
             output.push_back((*token).second);
             return;
         }
@@ -3058,14 +3058,14 @@ private:
         if (p == rev_merge.end()) {
             // output any symbols that did not form tokens as bytes.
             for (int j = 0; j < (int)symbol.n; ++j) {
-                llama_vocab::id token_id = llama_byte_to_token(vocab_, symbol.text[j]);
+                llama_vocab::id token_id = llama_byte_to_token(vocab, symbol.text[j]);
                 output.push_back(token_id);
             }
             return;
         }
 
-        resegment(symbols_[p->second.first], output);
-        resegment(symbols_[p->second.second], output);
+        resegment(symbols[p->second.first], output);
+        resegment(symbols[p->second.second], output);
     }
 
     void try_add_bigram(int left, int right) {
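resegment() has two recovery paths when a final chain piece is not itself a vocab token: rev_merge records which two slots built each merged piece so it can be split back recursively, and anything that never formed a token is emitted byte by byte via llama_byte_to_token. A toy of the rev_merge bookkeeping (values invented for illustration):

    #include <cstdio>
    #include <map>
    #include <string>
    #include <utility>

    int main() {
        // Each merged piece remembers the two chain slots it was built from.
        std::map<std::string, std::pair<int, int>> rev_merge;
        rev_merge["ab"]  = std::make_pair(0, 1); // "ab"  <- slot 0 ("a") + slot 1 ("b")
        rev_merge["abc"] = std::make_pair(0, 2); // "abc" <- slot 0 ("ab") + slot 2 ("c")

        // If "abc" is not a vocab token, resegment splits it back into the two
        // pieces that formed it and recurses; a piece with no entry falls back
        // to per-byte tokens.
        const auto p = rev_merge.find("abc");
        if (p != rev_merge.end()) {
            printf("\"abc\" splits back into slots %d and %d\n", p->second.first, p->second.second);
        }
        return 0;
    }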
@@ -3073,34 +3073,36 @@ private:
             return;
         }
 
-        const std::string text = std::string(symbols_[left].text, symbols_[left].n + symbols_[right].n);
-        auto token = vocab_.token_to_id.find(text);
+        const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
+        auto token = vocab.token_to_id.find(text);
 
-        if (token == vocab_.token_to_id.end()) {
+        if (token == vocab.token_to_id.end()) {
             return;
         }
 
-        if (static_cast<size_t>((*token).second) >= vocab_.id_to_token.size()) {
+        if (static_cast<size_t>((*token).second) >= vocab.id_to_token.size()) {
             return;
         }
 
-        const auto &tok_data = vocab_.id_to_token[(*token).second];
+        const auto & tok_data = vocab.id_to_token[(*token).second];
 
-        llama_sp_bigram bigram;
-        bigram.left = left;
+        sp_bigram bigram;
+        bigram.left  = left;
         bigram.right = right;
         bigram.score = tok_data.score;
-        bigram.size = text.size();
-        work_queue_.push(bigram);
+        bigram.size  = text.size();
+
+        work_queue.push(bigram);
 
         // Do we need to support is_unused?
         rev_merge[text] = std::make_pair(left, right);
     }
 
-    const llama_vocab & vocab_;
-    std::vector<llama_sp_symbol> symbols_;
-    llama_sp_bigram::queue work_queue_;
-    std::map<std::string, std::pair<int, int> > rev_merge;
+    const llama_vocab & vocab;
+
+    std::vector<sp_symbol> symbols;
+    sp_bigram::queue work_queue;
+    std::map<std::string, std::pair<int, int>> rev_merge;
 };
 
 static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape) {
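For readers new to this file, the following self-contained toy condenses the whole algorithm the diff touches: seed all adjacent pairs, greedily merge the best-scoring pair, re-seed around the merge, then walk the chain. A made-up score table stands in for llama_vocab, and resegment plus the byte fallback are omitted; it illustrates the technique, not the commit's code:

    #include <cstdio>
    #include <map>
    #include <queue>
    #include <string>
    #include <vector>

    int main() {
        // Hypothetical piece scores (higher is better), standing in for
        // tok_data.score from llama_vocab.
        std::map<std::string, float> vocab = {
            { "he", -3.0f }, { "ll", -2.0f }, { "llo", -1.5f }, { "hello", -1.0f },
        };

        struct sym { int prev; int next; size_t pos; size_t n; };
        struct bigram { int left; int right; float score; size_t size; };
        struct cmp {
            bool operator()(const bigram & l, const bigram & r) const {
                return (l.score < r.score) || (l.score == r.score && l.left > r.left);
            }
        };

        const std::string text = "hello";

        // split into single-char symbols (the real code splits into utf8 chars)
        std::vector<sym> symbols;
        for (int i = 0; i < (int) text.size(); ++i) {
            symbols.push_back({ i - 1, i + 1 == (int) text.size() ? -1 : i + 1, (size_t) i, 1 });
        }

        std::priority_queue<bigram, std::vector<bigram>, cmp> work_queue;

        auto try_add_bigram = [&](int left, int right) {
            if (left == -1 || right == -1) {
                return;
            }
            const std::string piece = text.substr(symbols[left].pos, symbols[left].n + symbols[right].n);
            const auto it = vocab.find(piece);
            if (it == vocab.end()) {
                return;
            }
            work_queue.push({ left, right, it->second, piece.size() });
        };

        // seed the work queue with all adjacent pairs
        for (size_t i = 1; i < symbols.size(); ++i) {
            try_add_bigram((int) i - 1, (int) i);
        }

        // keep substituting the best-scoring pairs for as long as we can
        while (!work_queue.empty()) {
            const auto big = work_queue.top();
            work_queue.pop();
            auto & left_sym  = symbols[big.left];
            auto & right_sym = symbols[big.right];
            if (left_sym.n == 0 || right_sym.n == 0 || left_sym.n + right_sym.n != big.size) {
                continue; // one side was already merged away; entry is stale
            }
            left_sym.n += right_sym.n;      // left absorbs right
            right_sym.n = 0;
            left_sym.next = right_sym.next; // splice right out of the chain
            if (right_sym.next >= 0) {
                symbols[right_sym.next].prev = big.left;
            }
            if (left_sym.prev >= 0) {
                try_add_bigram(left_sym.prev, big.left);
            }
            if (left_sym.next >= 0) {
                try_add_bigram(big.left, left_sym.next);
            }
        }

        // walk the chain; with these scores "hello" ends up as a single piece
        for (int i = 0; i != -1; i = symbols[i].next) {
            printf("piece: %s\n", text.substr(symbols[i].pos, symbols[i].n).c_str());
        }
        return 0;
    }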