Fix punctuation split

This commit is contained in:
bobqianic 2024-02-20 19:03:51 +00:00 committed by GitHub
parent 93e2c73dba
commit 1b8da8e0a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,6 +1,7 @@
#pragma once
#include "unicode.h"
#include "unordered_set"
class llm_regex {
public:
@ -23,7 +24,7 @@ public:
auto codepoints = unicode_engine.to_codepoints(str);
for (auto & cp_1 : split_punctuation_unicode(codepoints)) {
for (auto & cp_1 : split_punctuation_unicode_ascii(codepoints)) {
for (auto & cp_2 : gpt2_style_implement(cp_1)) {
for (auto & cp_3 : split_digits_unicode(cp_2)) {
for (auto & cp_4 : split_continuous_digits_ascii(cp_3)) {
@ -140,7 +141,7 @@ private:
}
// contiguous mode only
std::vector<std::vector<uint32_t>> split_punctuation_unicode(const std::vector<uint32_t> & codepoints) {
std::vector<std::vector<uint32_t>> split_punctuation_unicode_ascii(const std::vector<uint32_t> & codepoints) {
std::vector<std::vector<uint32_t>> results;
results.reserve(codepoints.size());
std::vector<uint32_t> codepoints_buffer;
@ -152,13 +153,13 @@ private:
codepoints_buffer.clear();
uint32_t codepoint = codepoints[offset];
if (unicode_engine.is_category(codepoint, "PUNCTUATION")) {
while (offset < codepoints.size() && unicode_engine.is_category(codepoints[offset], "PUNCTUATION")) {
if (is_ascii_punctuation(codepoint) || unicode_engine.is_category(codepoint, "PUNCTUATION")) {
while (offset < codepoints.size() && (is_ascii_punctuation(codepoints[offset]) || unicode_engine.is_category(codepoints[offset], "PUNCTUATION"))) {
codepoints_buffer.push_back(codepoints[offset]);
offset++;
}
} else {
while (offset < codepoints.size() && !unicode_engine.is_category(codepoints[offset], "PUNCTUATION")) {
while (offset < codepoints.size() && !(is_ascii_punctuation(codepoints[offset]) || unicode_engine.is_category(codepoints[offset], "PUNCTUATION"))) {
codepoints_buffer.push_back(codepoints[offset]);
offset++;
}
@ -239,4 +240,13 @@ private:
return results;
}
/// Reports whether `codepoint` is one of the 32 printable, non-alphanumeric
/// ASCII characters (punctuation and symbols).
///
/// The accepted codepoints are exactly these four contiguous runs:
///   33..47    ! " # $ % & ' ( ) * + , - . /
///   58..64    : ; < = > ? @
///   91..96    [ backslash ] ^ _ `
///   123..126  { | } ~
/// Everything else returns false, including space (32), digits, letters,
/// DEL (127) and all codepoints above the ASCII range.
///
/// @param codepoint  Codepoint to classify.
/// @return true if the codepoint lies in one of the ranges above.
static bool is_ascii_punctuation(const uint32_t & codepoint) {
    // Plain range comparisons instead of a function-local static hash set:
    // no heap allocation, no mutable static state, no hashing per lookup.
    return (codepoint >= 33  && codepoint <= 47)  ||
           (codepoint >= 58  && codepoint <= 64)  ||
           (codepoint >= 91  && codepoint <= 96)  ||
           (codepoint >= 123 && codepoint <= 126);
}
};