Refactor: Add llamma_ prefix in unicode.h unicode.cpp

This commit is contained in:
MichelleTPY 2024-12-14 14:18:25 +00:00
parent ba1cb19cdd
commit 0579e3bf65
2 changed files with 34 additions and 35 deletions

View file

@ -120,8 +120,8 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
// return result;
//}
static std::vector<codepoint_flags> unicode_cpt_flags_array() {
std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
static std::vector<llama_codepoint_flags> unicode_cpt_flags_array() {
std::vector<llama_codepoint_flags> cpt_flags(MAX_CODEPOINTS, llama_codepoint_flags::LLAMA_UNDEFINED);
assert (unicode_ranges_flags.begin()[0].first == 0);
assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
@ -253,8 +253,8 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
auto _get_flags = [&](const size_t pos) -> llama_codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : llama_codepoint_flags{};
};
size_t _prev_end = offset_ini;
@ -371,8 +371,8 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
auto _get_flags = [&](const size_t pos) -> llama_codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : llama_codepoint_flags{};
};
size_t _prev_end = offset_ini;
@ -624,14 +624,14 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
return result;
}
codepoint_flags unicode_cpt_flags(const uint32_t cp) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
llama_codepoint_flags unicode_cpt_flags(const uint32_t cp) {
static const llama_codepoint_flags undef(llama_codepoint_flags::LLAMA_UNDEFINED);
static const auto cpt_flags = unicode_cpt_flags_array();
return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
}
codepoint_flags unicode_cpt_flags(const std::string & utf8) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
llama_codepoint_flags unicode_cpt_flags(const std::string & utf8) {
static const llama_codepoint_flags undef(llama_codepoint_flags::LLAMA_UNDEFINED);
if (utf8.empty()) {
return undef; // undefined
}
@ -664,21 +664,22 @@ uint32_t unicode_tolower(uint32_t cp) {
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
// unicode categories
static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", codepoint_flags::NUMBER },
{ "\\p{L}", codepoint_flags::LETTER },
{ "\\p{P}", codepoint_flags::PUNCTUATION },
{ "\\p{N}", llama_codepoint_flags::LLAMA_NUMBER },
{ "\\p{L}", llama_codepoint_flags::LLAMA_LETTER },
{ "\\p{P}", llama_codepoint_flags::LLAMA_PUNCTUATION },
};
static const std::map<int, int> k_ucat_cpt = {
{ codepoint_flags::NUMBER, 0xD1 },
{ codepoint_flags::LETTER, 0xD2 },
{ codepoint_flags::PUNCTUATION, 0xD3 },
{ llama_codepoint_flags::LLAMA_NUMBER, 0xD1 },
{ llama_codepoint_flags::LLAMA_LETTER, 0xD2 },
{ llama_codepoint_flags::LLAMA_PUNCTUATION, 0xD3 },
};
static const std::map<int, std::string> k_ucat_map = {
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
{ llama_codepoint_flags::LLAMA_NUMBER, "\x30-\x39" }, // 0-9
{ llama_codepoint_flags::LLAMA_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ llama_codepoint_flags::LLAMA_PUNCTUATION,
"\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
};
// compute collapsed codepoints only if needed by at least one regex

View file

@ -4,19 +4,17 @@
#include <string>
#include <vector>
// TODO: prefix all symbols with "llama_"
struct codepoint_flags {
struct llama_codepoint_flags {
enum {
UNDEFINED = 0x0001,
NUMBER = 0x0002, // regex: \p{N}
LETTER = 0x0004, // regex: \p{L}
SEPARATOR = 0x0008, // regex: \p{Z}
ACCENT_MARK = 0x0010, // regex: \p{M}
PUNCTUATION = 0x0020, // regex: \p{P}
SYMBOL = 0x0040, // regex: \p{S}
CONTROL = 0x0080, // regex: \p{C}
MASK_CATEGORIES = 0x00FF,
LLAMA_UNDEFINED = 0x0001,
LLAMA_NUMBER = 0x0002, // regex: \p{N}
LLAMA_LETTER = 0x0004, // regex: \p{L}
LLAMA_SEPARATOR = 0x0008, // regex: \p{Z}
LLAMA_ACCENT_MARK = 0x0010, // regex: \p{M}
LLAMA_PUNCTUATION = 0x0020, // regex: \p{P}
LLAMA_SYMBOL = 0x0040, // regex: \p{S}
LLAMA_CONTROL = 0x0080, // regex: \p{C}
LLAMA_MASK_CATEGORIES = 0x00FF,
};
// codepoint type
@ -35,7 +33,7 @@ struct codepoint_flags {
uint16_t is_nfd : 1;
// decode from uint16
inline codepoint_flags(const uint16_t flags=0) {
inline llama_codepoint_flags(const uint16_t flags = 0) {
*reinterpret_cast<uint16_t*>(this) = flags;
}
@ -44,7 +42,7 @@ struct codepoint_flags {
}
inline uint16_t category_flag() const {
return this->as_uint() & MASK_CATEGORIES;
return this->as_uint() & LLAMA_MASK_CATEGORIES;
}
};
@ -56,8 +54,8 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
codepoint_flags unicode_cpt_flags(const uint32_t cp);
codepoint_flags unicode_cpt_flags(const std::string & utf8);
llama_codepoint_flags unicode_cpt_flags(const uint32_t cp);
llama_codepoint_flags unicode_cpt_flags(const std::string & utf8);
std::string unicode_byte_to_utf8(uint8_t byte);
uint8_t unicode_utf8_to_byte(const std::string & utf8);