Refactor: Add llamma_ prefix in unicode.h unicode.cpp

This commit is contained in:
MichelleTPY 2024-12-14 14:18:25 +00:00
parent ba1cb19cdd
commit 0579e3bf65
2 changed files with 34 additions and 35 deletions

View file

@ -120,8 +120,8 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
// return result; // return result;
//} //}
static std::vector<codepoint_flags> unicode_cpt_flags_array() { static std::vector<llama_codepoint_flags> unicode_cpt_flags_array() {
std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED); std::vector<llama_codepoint_flags> cpt_flags(MAX_CODEPOINTS, llama_codepoint_flags::LLAMA_UNDEFINED);
assert (unicode_ranges_flags.begin()[0].first == 0); assert (unicode_ranges_flags.begin()[0].first == 0);
assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS); assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
@ -253,8 +253,8 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
}; };
auto _get_flags = [&] (const size_t pos) -> codepoint_flags { auto _get_flags = [&](const size_t pos) -> llama_codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{}; return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : llama_codepoint_flags{};
}; };
size_t _prev_end = offset_ini; size_t _prev_end = offset_ini;
@ -371,8 +371,8 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
}; };
auto _get_flags = [&] (const size_t pos) -> codepoint_flags { auto _get_flags = [&](const size_t pos) -> llama_codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{}; return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : llama_codepoint_flags{};
}; };
size_t _prev_end = offset_ini; size_t _prev_end = offset_ini;
@ -624,14 +624,14 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
return result; return result;
} }
codepoint_flags unicode_cpt_flags(const uint32_t cp) { llama_codepoint_flags unicode_cpt_flags(const uint32_t cp) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED); static const llama_codepoint_flags undef(llama_codepoint_flags::LLAMA_UNDEFINED);
static const auto cpt_flags = unicode_cpt_flags_array(); static const auto cpt_flags = unicode_cpt_flags_array();
return cp < cpt_flags.size() ? cpt_flags[cp] : undef; return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
} }
codepoint_flags unicode_cpt_flags(const std::string & utf8) { llama_codepoint_flags unicode_cpt_flags(const std::string & utf8) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED); static const llama_codepoint_flags undef(llama_codepoint_flags::LLAMA_UNDEFINED);
if (utf8.empty()) { if (utf8.empty()) {
return undef; // undefined return undef; // undefined
} }
@ -664,21 +664,22 @@ uint32_t unicode_tolower(uint32_t cp) {
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) { std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
// unicode categories // unicode categories
static const std::map<std::string, int> k_ucat_enum = { static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", codepoint_flags::NUMBER }, { "\\p{N}", llama_codepoint_flags::LLAMA_NUMBER },
{ "\\p{L}", codepoint_flags::LETTER }, { "\\p{L}", llama_codepoint_flags::LLAMA_LETTER },
{ "\\p{P}", codepoint_flags::PUNCTUATION }, { "\\p{P}", llama_codepoint_flags::LLAMA_PUNCTUATION },
}; };
static const std::map<int, int> k_ucat_cpt = { static const std::map<int, int> k_ucat_cpt = {
{ codepoint_flags::NUMBER, 0xD1 }, { llama_codepoint_flags::LLAMA_NUMBER, 0xD1 },
{ codepoint_flags::LETTER, 0xD2 }, { llama_codepoint_flags::LLAMA_LETTER, 0xD2 },
{ codepoint_flags::PUNCTUATION, 0xD3 }, { llama_codepoint_flags::LLAMA_PUNCTUATION, 0xD3 },
}; };
static const std::map<int, std::string> k_ucat_map = { static const std::map<int, std::string> k_ucat_map = {
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9 { llama_codepoint_flags::LLAMA_NUMBER, "\x30-\x39" }, // 0-9
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z { llama_codepoint_flags::LLAMA_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\} { llama_codepoint_flags::LLAMA_PUNCTUATION,
"\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
}; };
// compute collapsed codepoints only if needed by at least one regex // compute collapsed codepoints only if needed by at least one regex

View file

@ -4,19 +4,17 @@
#include <string> #include <string>
#include <vector> #include <vector>
// TODO: prefix all symbols with "llama_" struct llama_codepoint_flags {
struct codepoint_flags {
enum { enum {
UNDEFINED = 0x0001, LLAMA_UNDEFINED = 0x0001,
NUMBER = 0x0002, // regex: \p{N} LLAMA_NUMBER = 0x0002, // regex: \p{N}
LETTER = 0x0004, // regex: \p{L} LLAMA_LETTER = 0x0004, // regex: \p{L}
SEPARATOR = 0x0008, // regex: \p{Z} LLAMA_SEPARATOR = 0x0008, // regex: \p{Z}
ACCENT_MARK = 0x0010, // regex: \p{M} LLAMA_ACCENT_MARK = 0x0010, // regex: \p{M}
PUNCTUATION = 0x0020, // regex: \p{P} LLAMA_PUNCTUATION = 0x0020, // regex: \p{P}
SYMBOL = 0x0040, // regex: \p{S} LLAMA_SYMBOL = 0x0040, // regex: \p{S}
CONTROL = 0x0080, // regex: \p{C} LLAMA_CONTROL = 0x0080, // regex: \p{C}
MASK_CATEGORIES = 0x00FF, LLAMA_MASK_CATEGORIES = 0x00FF,
}; };
// codepoint type // codepoint type
@ -35,7 +33,7 @@ struct codepoint_flags {
uint16_t is_nfd : 1; uint16_t is_nfd : 1;
// decode from uint16 // decode from uint16
inline codepoint_flags(const uint16_t flags=0) { inline llama_codepoint_flags(const uint16_t flags = 0) {
*reinterpret_cast<uint16_t*>(this) = flags; *reinterpret_cast<uint16_t*>(this) = flags;
} }
@ -44,7 +42,7 @@ struct codepoint_flags {
} }
inline uint16_t category_flag() const { inline uint16_t category_flag() const {
return this->as_uint() & MASK_CATEGORIES; return this->as_uint() & LLAMA_MASK_CATEGORIES;
} }
}; };
@ -56,8 +54,8 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts); std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
codepoint_flags unicode_cpt_flags(const uint32_t cp); llama_codepoint_flags unicode_cpt_flags(const uint32_t cp);
codepoint_flags unicode_cpt_flags(const std::string & utf8); llama_codepoint_flags unicode_cpt_flags(const std::string & utf8);
std::string unicode_byte_to_utf8(uint8_t byte); std::string unicode_byte_to_utf8(uint8_t byte);
uint8_t unicode_utf8_to_byte(const std::string & utf8); uint8_t unicode_utf8_to_byte(const std::string & utf8);