This commit is contained in:
bobqianic 2024-02-26 11:07:34 +00:00 committed by GitHub
commit d90875a8f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 1590 additions and 611 deletions

150
llama.cpp
View file

@ -1,7 +1,8 @@
#define LLAMA_API_INTERNAL
#include "llama.h"
#include "unicode.h"
// #include "unicode.h"
#include "unicode_regex.h"
#include "ggml.h"
#include "ggml-alloc.h"
@ -114,6 +115,13 @@ static void llama_log_callback_default(ggml_log_level level, const char * text,
#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
//
// unicode utilities
//
static llm_regex regex_engine;
auto unicode_engine = regex_engine.get_unicode_engine();
//
// helpers
//
@ -3388,7 +3396,7 @@ static void llm_load_vocab(
for (int i = 0; i < n_merges; i++) {
const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
GGML_ASSERT(codepoints_from_utf8(word).size() > 0);
GGML_ASSERT(unicode_engine.to_codepoints(word).size() > 0);
std::string first;
std::string second;
@ -3433,7 +3441,7 @@ static void llm_load_vocab(
for (uint32_t i = 0; i < n_vocab; i++) {
std::string word = gguf_get_arr_str(ctx, token_idx, i);
GGML_ASSERT(codepoints_from_utf8(word).size() > 0);
GGML_ASSERT(unicode_engine.to_codepoints(word).size() > 0);
vocab.token_to_id[word] = i;
@ -8344,7 +8352,7 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
}
case LLAMA_VOCAB_TYPE_BPE: {
GGML_ASSERT(false);
return unicode_to_bytes_bpe(token_data.text);
return unicode_engine.unicode_to_bytes_bpe(token_data.text);
}
case LLAMA_VOCAB_TYPE_WPM: {
GGML_ASSERT(false);
@ -8369,7 +8377,7 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
}
case LLAMA_VOCAB_TYPE_WPM:
case LLAMA_VOCAB_TYPE_BPE: {
return vocab.token_to_id.at(bytes_to_unicode_bpe(ch));
return vocab.token_to_id.at(unicode_engine.bytes_to_unicode_bpe(ch));
}
default:
GGML_ASSERT(false);
@ -8693,137 +8701,13 @@ private:
}
std::vector<std::string> bpe_gpt2_preprocess(const std::string & text) {
std::vector<std::string> bpe_words;
std::vector<std::string> bpe_encoded_words;
std::string token = "";
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
bool collecting_numeric = false;
bool collecting_letter = false;
bool collecting_special = false;
bool collecting_whitespace_lookahead = false;
bool collecting = false;
std::vector<std::string> text_utf;
text_utf.reserve(text.size());
bpe_words.reserve(text.size());
bpe_encoded_words.reserve(text.size());
auto cps = codepoints_from_utf8(text);
for (size_t i = 0; i < cps.size(); ++i)
text_utf.emplace_back(codepoint_to_utf8(cps[i]));
for (int i = 0; i < (int)text_utf.size(); i++) {
const std::string & utf_char = text_utf[i];
bool split_condition = false;
int bytes_remain = text_utf.size() - i;
// forward backward lookups
const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";
// handling contractions
if (!split_condition && bytes_remain >= 2) {
// 's|'t|'m|'d
if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) {
split_condition = true;
}
if (split_condition) {
if (token.size()) {
bpe_words.emplace_back(token); // push previous content as token
}
token = utf_char + utf_char_next;
bpe_words.emplace_back(token);
token = "";
i++;
continue;
}
}
if (!split_condition && bytes_remain >= 3) {
// 're|'ve|'ll
if (utf_char == "\'" && (
(utf_char_next == "r" && utf_char_next_next == "e") ||
(utf_char_next == "v" && utf_char_next_next == "e") ||
(utf_char_next == "l" && utf_char_next_next == "l"))
) {
split_condition = true;
}
if (split_condition) {
// current token + next token can be defined
if (token.size()) {
bpe_words.emplace_back(token); // push previous content as token
}
token = utf_char + utf_char_next + utf_char_next_next;
bpe_words.emplace_back(token); // the contraction
token = "";
i += 2;
continue;
}
}
if (!split_condition && !collecting) {
if (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
collecting_letter = true;
collecting = true;
}
else if (codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
collecting_numeric = true;
collecting = true;
}
else if (
((codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (codepoint_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
(!token.size() && utf_char == " " && codepoint_type(utf_char_next) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && codepoint_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
) {
collecting_special = true;
collecting = true;
}
else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && codepoint_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
collecting_whitespace_lookahead = true;
collecting = true;
}
else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
split_condition = true;
}
}
else if (!split_condition && collecting) {
if (collecting_letter && codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER) {
split_condition = true;
}
else if (collecting_numeric && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
split_condition = true;
}
else if (collecting_special && (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
split_condition = true;
}
else if (collecting_whitespace_lookahead && (codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
split_condition = true;
}
}
if (utf_char_next == "") {
split_condition = true; // final
token += utf_char;
}
if (split_condition) {
if (token.size()) {
bpe_words.emplace_back(token);
}
token = utf_char;
collecting = false;
collecting_letter = false;
collecting_numeric = false;
collecting_special = false;
collecting_whitespace_lookahead = false;
}
else {
token += utf_char;
}
}
for (std::string & word : bpe_words) {
for (std::string & word : regex_engine.falcon_style(text)) {
std::string encoded_token = "";
for (char & c : word) {
encoded_token += bytes_to_unicode_bpe(c);
encoded_token += unicode_engine.bytes_to_unicode_bpe(c);
}
bpe_encoded_words.emplace_back(encoded_token);
}
@ -12964,9 +12848,9 @@ int32_t llama_tokenize(
static std::string llama_decode_text(const std::string & text) {
std::string decoded_text;
auto unicode_sequences = codepoints_from_utf8(text);
auto unicode_sequences = unicode_engine.to_codepoints(text);
for (auto& unicode_sequence : unicode_sequences) {
decoded_text += unicode_to_bytes_bpe(codepoint_to_utf8(unicode_sequence));
decoded_text += unicode_engine.unicode_to_bytes_bpe(unicode_engine.to_string(unicode_sequence));
}
return decoded_text;

View file

@ -38,6 +38,7 @@ static const std::map<std::string, std::vector<llama_token>> & k_tests() {
{ " Hello\n Hello" , { 466, 23090, 742, 23090, }, },
{ "\n =" , { 1212, 40, }, },
{ "' era" , { 18, 4932, }, },
{ "12345678-1239-0fsjk" , { 10963, 27681, 5070, 24, 10963, 36, 24, 27, 5577, 85, 86, }, },
};
return _k_tests;

View file

@ -64,7 +64,7 @@ int main(int argc, char **argv) {
for (int i = 0; i < n_vocab; ++i) {
std::string str = llama_detokenize_bpe(ctx, std::vector<int>(1, i));
try {
auto cps = codepoints_from_utf8(str);
auto cps = (str);
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
std::string check = llama_detokenize_bpe(ctx, tokens);
if (check != str) {
@ -80,6 +80,7 @@ int main(int argc, char **argv) {
// unicode
{
static UNICODE unicode_engine;
const int nthread = std::thread::hardware_concurrency();
std::vector<std::thread> threads(nthread);
@ -97,7 +98,7 @@ int main(int argc, char **argv) {
continue;
}
std::string str = codepoint_to_utf8(cp);
std::string str = unicode_engine.to_string(cp);
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
std::string check = llama_detokenize_bpe(ctx, tokens);
if (cp != 9601 && str != check) {

View file

@ -74,6 +74,7 @@ int main(int argc, char **argv) {
// unicode
{
static UNICODE unicode_engine;
const int nthread = std::thread::hardware_concurrency();
std::vector<std::thread> threads(nthread);
@ -85,7 +86,7 @@ int main(int argc, char **argv) {
continue;
}
std::string str = codepoint_to_utf8(cp);
std::string str = unicode_engine.to_string(cp);
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
std::string check = llama_detokenize_spm(ctx, tokens);
if (cp != 9601 && str != check) {

1790
unicode.h

File diff suppressed because it is too large Load diff

252
unicode_regex.h Normal file
View file

@ -0,0 +1,252 @@
#pragma once
#include "unicode.h"
#include "unordered_set"
class llm_regex {
public:
std::vector<std::string> gpt2_style(const std::string & str) {
std::vector<std::string> results;
results.reserve(str.size());
auto codepoints = unicode_engine.to_codepoints(str);
for (auto & cp : gpt2_style_implement(codepoints)) {
results.push_back(unicode_engine.to_string(cp));
}
return results;
}
std::vector<std::string> falcon_style(const std::string & str) {
std::vector<std::string> results;
results.reserve(str.size());
auto codepoints = unicode_engine.to_codepoints(str);
for (auto & cp_1 : split_punctuation_unicode_ascii(codepoints)) {
for (auto & cp_2 : gpt2_style_implement(cp_1)) {
for (auto & cp_3 : split_digits_unicode(cp_2)) {
for (auto & cp_4 : split_continuous_digits_ascii(cp_3)) {
results.push_back(unicode_engine.to_string(cp_4));
}
}
}
}
return results;
}
UNICODE & get_unicode_engine() {
return unicode_engine;
}
llm_regex() {
unicode_engine.overload_category(REGEX_RANGES::Whitespace, "WHITESPACE");
}
private:
UNICODE unicode_engine;
// Very basic match no metacharacter support
static bool basic_match(const std::vector<std::vector<uint32_t>> & codepoint_rules,
const std::vector<uint32_t> & codepoints,
std::vector<std::vector<uint32_t>> & output,
size_t & offset) {
for (auto & codepoint_rule : codepoint_rules) {
bool satisfy = true;
for (size_t ru_index = 0; ru_index < codepoint_rule.size(); ru_index++) {
if (offset + ru_index >= codepoints.size() || codepoint_rule[ru_index] != codepoints[offset + ru_index]) {
satisfy = false;
break;
}
}
if (satisfy) {
output.push_back(codepoint_rule);
offset += codepoint_rule.size();
return true;
}
}
return false;
}
// "behavior": "Isolated"
// separate any continuous digits longer than 2
static std::vector<std::vector<uint32_t>> split_continuous_digits_ascii(const std::vector<uint32_t> & codepoints) {
std::vector<std::vector<uint32_t>> results;
results.reserve(codepoints.size());
std::vector<uint32_t> codepoints_buffer;
codepoints_buffer.reserve(codepoints.size());
size_t offset = 0;
while (offset < codepoints.size()) {
codepoints_buffer.clear();
uint32_t codepoint = codepoints[offset];
uint32_t counter = 0;
if (codepoint >= 48 && codepoint <= 57) {
while (offset < codepoints.size() && codepoints[offset] >= 48 && codepoints[offset] <= 57 && counter < 3) {
codepoints_buffer.push_back(codepoints[offset]);
offset++;
counter++;
}
} else {
while (offset < codepoints.size() && (codepoints[offset] < 48 || codepoints[offset] > 57)) {
codepoints_buffer.push_back(codepoints[offset]);
offset++;
}
}
if (!codepoints_buffer.empty()) {
results.push_back(codepoints_buffer);
}
}
return results;
}
// "individual_digits": false
std::vector<std::vector<uint32_t>> split_digits_unicode(const std::vector<uint32_t> & codepoints) {
std::vector<std::vector<uint32_t>> results;
results.reserve(codepoints.size());
std::vector<uint32_t> codepoints_buffer;
codepoints_buffer.reserve(codepoints.size());
size_t offset = 0;
while (offset < codepoints.size()) {
codepoints_buffer.clear();
uint32_t codepoint = codepoints[offset];
if (unicode_engine.is_category(codepoint, "NUMBER")) {
while (offset < codepoints.size() && unicode_engine.is_category(codepoints[offset], "NUMBER")) {
codepoints_buffer.push_back(codepoints[offset]);
offset++;
}
} else {
while (offset < codepoints.size() && !unicode_engine.is_category(codepoints[offset], "NUMBER")) {
codepoints_buffer.push_back(codepoints[offset]);
offset++;
}
}
if (!codepoints_buffer.empty()) {
results.push_back(codepoints_buffer);
}
}
return results;
}
// contiguous mode only
std::vector<std::vector<uint32_t>> split_punctuation_unicode_ascii(const std::vector<uint32_t> & codepoints) {
std::vector<std::vector<uint32_t>> results;
results.reserve(codepoints.size());
std::vector<uint32_t> codepoints_buffer;
codepoints_buffer.reserve(codepoints.size());
size_t offset = 0;
while (offset < codepoints.size()) {
codepoints_buffer.clear();
uint32_t codepoint = codepoints[offset];
if (is_ascii_punctuation(codepoint) || unicode_engine.is_category(codepoint, "PUNCTUATION")) {
while (offset < codepoints.size() && (is_ascii_punctuation(codepoints[offset]) || unicode_engine.is_category(codepoints[offset], "PUNCTUATION"))) {
codepoints_buffer.push_back(codepoints[offset]);
offset++;
}
} else {
while (offset < codepoints.size() && !(is_ascii_punctuation(codepoints[offset]) || unicode_engine.is_category(codepoints[offset], "PUNCTUATION"))) {
codepoints_buffer.push_back(codepoints[offset]);
offset++;
}
}
if (!codepoints_buffer.empty()) {
results.push_back(codepoints_buffer);
}
}
return results;
}
std::vector<std::vector<uint32_t>> gpt2_style_implement(const std::vector<uint32_t> & codepoints) {
std::vector<std::vector<uint32_t>> results;
results.reserve(codepoints.size());
std::vector<uint32_t> codepoints_buffer;
codepoints_buffer.reserve(codepoints.size());
size_t offset = 0;
static auto codepoint_rules_1 = unicode_engine.to_codepoints({"'s", "'t", "'re", "'ve", "'m", "ll", "'d"});
static auto codepoint_rules_2 = unicode_engine.to_category_code({"WHITESPACE", "LETTER", "NUMBER"});
while (offset < codepoints.size()) {
codepoints_buffer.clear();
uint32_t codepoint = codepoints[offset];
uint32_t codepoint_next = (offset + 1 < codepoints.size()) ? codepoints[offset + 1] : 0xFFFFFFFF;
//'s|'t|'re|'ve|'m|'ll|'d
if (basic_match(codepoint_rules_1, codepoints, results, offset)) {
continue;
}
// ?\p{L}+
else if (unicode_engine.is_category(codepoint, "LETTER") || (codepoint == 32 && unicode_engine.is_category(codepoint_next, "LETTER"))) {
codepoints_buffer.push_back(codepoint);
offset++;
while (offset < codepoints.size() && unicode_engine.is_category(codepoints[offset], "LETTER")) {
codepoints_buffer.push_back(codepoints[offset]);
offset++;
}
}
// ?\p{N}+
else if (unicode_engine.is_category(codepoint, "NUMBER") || (codepoint == 32 && unicode_engine.is_category(codepoint_next, "NUMBER"))) {
codepoints_buffer.push_back(codepoint);
offset++;
while (offset < codepoints.size() && unicode_engine.is_category(codepoints[offset], "NUMBER")) {
codepoints_buffer.push_back(codepoints[offset]);
offset++;
}
}
// ?[^\s\p{L}\p{N}]+
else if (!unicode_engine.is_category(codepoint, codepoint_rules_2) || (codepoint == 32 && !unicode_engine.is_category(codepoint_next, codepoint_rules_2))) {
codepoints_buffer.push_back(codepoint);
offset++;
while (offset < codepoints.size() && !unicode_engine.is_category(codepoints[offset], codepoint_rules_2)) {
codepoints_buffer.push_back(codepoints[offset]);
offset++;
}
}
//\s+(?!\S)|\s+
else if (unicode_engine.is_category(codepoint, "WHITESPACE")) {
codepoints_buffer.push_back(codepoint);
offset++;
while (offset < codepoints.size() && unicode_engine.is_category(codepoints[offset], "WHITESPACE")) {
if (offset + 1 < codepoints.size() && !unicode_engine.is_category(codepoints[offset+1], "WHITESPACE")) { break;}
codepoints_buffer.push_back(codepoints[offset]);
offset++;
}
} else {
offset++;
}
if (!codepoints_buffer.empty()) {
results.push_back(codepoints_buffer);
}
}
return results;
}
static bool is_ascii_punctuation(const uint32_t & codepoint) {
static std::unordered_set<uint32_t> ascii_punctuation = {33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46,
47, 58, 59, 60, 61, 62, 63, 64, 91, 92, 93, 94, 95, 96,
123, 124, 125, 126};
auto it = ascii_punctuation.find(codepoint);
return it != ascii_punctuation.end();
}
};