diff --git a/unicode_regex.h b/unicode_regex.h index 40115f120..bb9a668c8 100644 --- a/unicode_regex.h +++ b/unicode_regex.h @@ -1,6 +1,7 @@ #pragma once #include "unicode.h" +#include "unordered_set" class llm_regex { public: @@ -23,7 +24,7 @@ public: auto codepoints = unicode_engine.to_codepoints(str); - for (auto & cp_1 : split_punctuation_unicode(codepoints)) { + for (auto & cp_1 : split_punctuation_unicode_ascii(codepoints)) { for (auto & cp_2 : gpt2_style_implement(cp_1)) { for (auto & cp_3 : split_digits_unicode(cp_2)) { for (auto & cp_4 : split_continuous_digits_ascii(cp_3)) { @@ -140,7 +141,7 @@ private: } // contiguous mode only - std::vector> split_punctuation_unicode(const std::vector & codepoints) { + std::vector> split_punctuation_unicode_ascii(const std::vector & codepoints) { std::vector> results; results.reserve(codepoints.size()); std::vector codepoints_buffer; @@ -152,13 +153,13 @@ private: codepoints_buffer.clear(); uint32_t codepoint = codepoints[offset]; - if (unicode_engine.is_category(codepoint, "PUNCTUATION")) { - while (offset < codepoints.size() && unicode_engine.is_category(codepoints[offset], "PUNCTUATION")) { + if (is_ascii_punctuation(codepoint) || unicode_engine.is_category(codepoint, "PUNCTUATION")) { + while (offset < codepoints.size() && (is_ascii_punctuation(codepoints[offset]) || unicode_engine.is_category(codepoints[offset], "PUNCTUATION"))) { codepoints_buffer.push_back(codepoints[offset]); offset++; } } else { - while (offset < codepoints.size() && !unicode_engine.is_category(codepoints[offset], "PUNCTUATION")) { + while (offset < codepoints.size() && !(is_ascii_punctuation(codepoints[offset]) || unicode_engine.is_category(codepoints[offset], "PUNCTUATION"))) { codepoints_buffer.push_back(codepoints[offset]); offset++; } @@ -239,4 +240,13 @@ private: return results; } + + static bool is_ascii_punctuation(const uint32_t & codepoint) { + static std::unordered_set ascii_punctuation = {33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, + 47, 58, 59, 60, 61, 62, 63, 64, 91, 92, 93, 94, 95, 96, + 123, 124, 125, 126}; + auto it = ascii_punctuation.find(codepoint); + + return it != ascii_punctuation.end(); + } };