llama : prep new tokenizer support
This commit is contained in:
parent
6938c5f474
commit
c3f8a6e49f
1 changed file with 63 additions and 61 deletions
124
llama.cpp
124
llama.cpp
|
@ -106,6 +106,12 @@ static void llama_log_callback_default(llama_log_level level, const char * text,
|
|||
// helpers
|
||||
//
|
||||
|
||||
// Returns the number of bytes in the UTF-8 sequence that starts with `src`,
// decided purely from the top 4 bits of that byte:
//   0x0-0xB -> 1, 0xC-0xD -> 2, 0xE -> 3, 0xF -> 4
// Note that bytes with high nibble 0x8-0xB (UTF-8 continuation bytes) also map to 1;
// no validation of the byte or of the bytes that follow is performed here.
static size_t utf8_len(char src) {
|
||||
// table indexed by the high nibble of the byte
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||
// cast to unsigned first so the shift yields a value in [0, 15] even when char is signed
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
|
||||
return lookup[highbits];
|
||||
}
|
||||
|
||||
static void zeros(std::ofstream & file, size_t n) {
|
||||
char zero = 0;
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
|
@ -2948,47 +2954,41 @@ static std::string llama_unescape_whitespace(const std::string& word) {
|
|||
return word;
|
||||
}
|
||||
|
||||
// Maps a byte to a UTF-8 sequence length using only its top 4 bits
// (high nibble 0x0-0xB -> 1, 0xC-0xD -> 2, 0xE -> 3, 0xF -> 4).
// The byte is not validated; the result is whatever the table says.
static size_t utf8_len(char src) {
|
||||
// one entry per possible high nibble
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||
// go through uint8_t so the index is always in [0, 15]
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
|
||||
return lookup[highbits];
|
||||
}
|
||||
|
||||
// One entry of the index-linked symbol list used by llama_tokenizer.
// Symbols are stored in a vector and chained through `prev`/`next` indices
// rather than pointers.
struct llama_sp_symbol {
|
||||
// position of a symbol inside the tokenizer's symbol vector
using index = int;
|
||||
// index of the preceding symbol (the tokenizer sets it to `index - 1`, i.e. -1 for the first symbol)
index prev;
|
||||
// index of the following symbol; the tokenizer uses -1 to mark the end of the chain
index next;
|
||||
// start of this symbol's bytes; the tokenizer points it into the input string, so it is
// non-owning and is always read together with `n`
const char * text;
|
||||
// number of bytes at `text`; the tokenizer treats n == 0 as "already merged into another symbol"
size_t n;
|
||||
};
|
||||
|
||||
static_assert(std::is_trivially_copyable<llama_sp_symbol>::value, "llama_sp_symbol is not trivially copyable");
|
||||
|
||||
// A candidate merge of two adjacent symbols, queued by llama_tokenizer::try_add_bigram.
struct llama_sp_bigram {
|
||||
// Ordering used by the priority queue below.
struct comparator {
|
||||
// Returns true when `l` should be served after `r`: lower score first loses,
// and on equal scores the bigram with the larger `left` index loses.
// Since std::priority_queue::top() yields the greatest element under this ordering,
// the queue pops the highest score first, ties broken by the smallest `left`.
bool operator()(llama_sp_bigram & l, llama_sp_bigram & r) {
|
||||
return (l.score < r.score) || (l.score == r.score && l.left > r.left);
|
||||
}
|
||||
};
|
||||
using queue_storage = std::vector<llama_sp_bigram>;
|
||||
using queue = std::priority_queue<llama_sp_bigram, queue_storage, comparator>;
|
||||
// index of the left symbol of the pair
llama_sp_symbol::index left;
|
||||
// index of the right symbol of the pair
llama_sp_symbol::index right;
|
||||
// score of the vocabulary token formed by the pair (filled from the token's `score` by try_add_bigram)
float score;
|
||||
// byte length of the concatenated text of both symbols (try_add_bigram stores text.size() here)
size_t size;
|
||||
};
|
||||
|
||||
// original implementation:
|
||||
// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
|
||||
struct llama_tokenizer {
|
||||
llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}
|
||||
struct sp_symbol {
|
||||
using index = int;
|
||||
index prev;
|
||||
index next;
|
||||
const char * text;
|
||||
size_t n;
|
||||
};
|
||||
|
||||
static_assert(std::is_trivially_copyable<sp_symbol>::value, "sp_symbol is not trivially copyable");
|
||||
|
||||
struct sp_bigram {
|
||||
struct comparator {
|
||||
bool operator()(sp_bigram & l, sp_bigram & r) {
|
||||
return (l.score < r.score) || (l.score == r.score && l.left > r.left);
|
||||
}
|
||||
};
|
||||
using queue_storage = std::vector<sp_bigram>;
|
||||
using queue = std::priority_queue<sp_bigram, queue_storage, comparator>;
|
||||
sp_symbol::index left;
|
||||
sp_symbol::index right;
|
||||
float score;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
llama_tokenizer(const llama_vocab & vocab): vocab(vocab) {}
|
||||
|
||||
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||
// split string into utf8 chars
|
||||
int index = 0;
|
||||
size_t offs = 0;
|
||||
while (offs < text.size()) {
|
||||
llama_sp_symbol sym;
|
||||
sp_symbol sym;
|
||||
size_t len = utf8_len(text[offs]);
|
||||
GGML_ASSERT(offs + len <= text.size());
|
||||
sym.text = text.c_str() + offs;
|
||||
|
@ -2997,21 +2997,21 @@ struct llama_tokenizer {
|
|||
sym.prev = index - 1;
|
||||
sym.next = offs == text.size() ? -1 : index + 1;
|
||||
index++;
|
||||
symbols_.emplace_back(sym);
|
||||
symbols.emplace_back(sym);
|
||||
}
|
||||
|
||||
// seed the work queue with all possible 2-character tokens.
|
||||
for (size_t i = 1; i < symbols_.size(); ++i) {
|
||||
for (size_t i = 1; i < symbols.size(); ++i) {
|
||||
try_add_bigram(i - 1, i);
|
||||
}
|
||||
|
||||
// keep substituting the highest frequency pairs for as long as we can.
|
||||
while (!work_queue_.empty()) {
|
||||
auto bigram = work_queue_.top();
|
||||
work_queue_.pop();
|
||||
while (!work_queue.empty()) {
|
||||
auto bigram = work_queue.top();
|
||||
work_queue.pop();
|
||||
|
||||
auto & left_sym = symbols_[bigram.left];
|
||||
auto & right_sym = symbols_[bigram.right];
|
||||
auto & left_sym = symbols[bigram.left];
|
||||
auto & right_sym = symbols[bigram.right];
|
||||
|
||||
// if one of the symbols already got merged, skip it.
|
||||
if (left_sym.n == 0 || right_sym.n == 0 ||
|
||||
|
@ -3028,7 +3028,7 @@ struct llama_tokenizer {
|
|||
// remove the right sym from the chain
|
||||
left_sym.next = right_sym.next;
|
||||
if (right_sym.next >= 0) {
|
||||
symbols_[right_sym.next].prev = bigram.left;
|
||||
symbols[right_sym.next].prev = bigram.left;
|
||||
}
|
||||
|
||||
// find more substitutions
|
||||
|
@ -3036,19 +3036,19 @@ struct llama_tokenizer {
|
|||
try_add_bigram(bigram.left, left_sym.next);
|
||||
}
|
||||
|
||||
for (int i = 0; i != -1; i = symbols_[i].next) {
|
||||
auto & symbol = symbols_[i];
|
||||
for (int i = 0; i != -1; i = symbols[i].next) {
|
||||
auto & symbol = symbols[i];
|
||||
resegment(symbol, output);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void resegment(llama_sp_symbol &symbol, std::vector<llama_vocab::id> &output) {
|
||||
void resegment(sp_symbol & symbol, std::vector<llama_vocab::id> & output) {
|
||||
auto text = std::string(symbol.text, symbol.n);
|
||||
auto token = vocab_.token_to_id.find(text);
|
||||
auto token = vocab.token_to_id.find(text);
|
||||
|
||||
// Do we need to support is_unused?
|
||||
if (token != vocab_.token_to_id.end()) {
|
||||
if (token != vocab.token_to_id.end()) {
|
||||
output.push_back((*token).second);
|
||||
return;
|
||||
}
|
||||
|
@ -3058,14 +3058,14 @@ private:
|
|||
if (p == rev_merge.end()) {
|
||||
// output any symbols that did not form tokens as bytes.
|
||||
for (int j = 0; j < (int)symbol.n; ++j) {
|
||||
llama_vocab::id token_id = llama_byte_to_token(vocab_, symbol.text[j]);
|
||||
llama_vocab::id token_id = llama_byte_to_token(vocab, symbol.text[j]);
|
||||
output.push_back(token_id);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
resegment(symbols_[p->second.first], output);
|
||||
resegment(symbols_[p->second.second], output);
|
||||
resegment(symbols[p->second.first], output);
|
||||
resegment(symbols[p->second.second], output);
|
||||
}
|
||||
|
||||
void try_add_bigram(int left, int right) {
|
||||
|
@ -3073,34 +3073,36 @@ private:
|
|||
return;
|
||||
}
|
||||
|
||||
const std::string text = std::string(symbols_[left].text, symbols_[left].n + symbols_[right].n);
|
||||
auto token = vocab_.token_to_id.find(text);
|
||||
const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
|
||||
auto token = vocab.token_to_id.find(text);
|
||||
|
||||
if (token == vocab_.token_to_id.end()) {
|
||||
if (token == vocab.token_to_id.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (static_cast<size_t>((*token).second) >= vocab_.id_to_token.size()) {
|
||||
if (static_cast<size_t>((*token).second) >= vocab.id_to_token.size()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const auto &tok_data = vocab_.id_to_token[(*token).second];
|
||||
const auto & tok_data = vocab.id_to_token[(*token).second];
|
||||
|
||||
llama_sp_bigram bigram;
|
||||
bigram.left = left;
|
||||
sp_bigram bigram;
|
||||
bigram.left = left;
|
||||
bigram.right = right;
|
||||
bigram.score = tok_data.score;
|
||||
bigram.size = text.size();
|
||||
work_queue_.push(bigram);
|
||||
bigram.size = text.size();
|
||||
|
||||
work_queue.push(bigram);
|
||||
|
||||
// Do we need to support is_unused?
|
||||
rev_merge[text] = std::make_pair(left, right);
|
||||
}
|
||||
|
||||
const llama_vocab & vocab_;
|
||||
std::vector<llama_sp_symbol> symbols_;
|
||||
llama_sp_bigram::queue work_queue_;
|
||||
std::map<std::string, std::pair<int, int> > rev_merge;
|
||||
const llama_vocab & vocab;
|
||||
|
||||
std::vector<sp_symbol> symbols;
|
||||
sp_bigram::queue work_queue;
|
||||
std::map<std::string, std::pair<int, int>> rev_merge;
|
||||
};
|
||||
|
||||
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue