Fix punctuation split
This commit is contained in:
parent
93e2c73dba
commit
1b8da8e0a6
1 changed file with 15 additions and 5 deletions
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "unicode.h"
|
||||
#include "unordered_set"
|
||||
|
||||
class llm_regex {
|
||||
public:
|
||||
|
@ -23,7 +24,7 @@ public:
|
|||
|
||||
auto codepoints = unicode_engine.to_codepoints(str);
|
||||
|
||||
for (auto & cp_1 : split_punctuation_unicode(codepoints)) {
|
||||
for (auto & cp_1 : split_punctuation_unicode_ascii(codepoints)) {
|
||||
for (auto & cp_2 : gpt2_style_implement(cp_1)) {
|
||||
for (auto & cp_3 : split_digits_unicode(cp_2)) {
|
||||
for (auto & cp_4 : split_continuous_digits_ascii(cp_3)) {
|
||||
|
@ -140,7 +141,7 @@ private:
|
|||
}
|
||||
|
||||
// contiguous mode only
|
||||
std::vector<std::vector<uint32_t>> split_punctuation_unicode(const std::vector<uint32_t> & codepoints) {
|
||||
std::vector<std::vector<uint32_t>> split_punctuation_unicode_ascii(const std::vector<uint32_t> & codepoints) {
|
||||
std::vector<std::vector<uint32_t>> results;
|
||||
results.reserve(codepoints.size());
|
||||
std::vector<uint32_t> codepoints_buffer;
|
||||
|
@ -152,13 +153,13 @@ private:
|
|||
codepoints_buffer.clear();
|
||||
uint32_t codepoint = codepoints[offset];
|
||||
|
||||
if (unicode_engine.is_category(codepoint, "PUNCTUATION")) {
|
||||
while (offset < codepoints.size() && unicode_engine.is_category(codepoints[offset], "PUNCTUATION")) {
|
||||
if (is_ascii_punctuation(codepoint) || unicode_engine.is_category(codepoint, "PUNCTUATION")) {
|
||||
while (offset < codepoints.size() && (is_ascii_punctuation(codepoints[offset]) || unicode_engine.is_category(codepoints[offset], "PUNCTUATION"))) {
|
||||
codepoints_buffer.push_back(codepoints[offset]);
|
||||
offset++;
|
||||
}
|
||||
} else {
|
||||
while (offset < codepoints.size() && !unicode_engine.is_category(codepoints[offset], "PUNCTUATION")) {
|
||||
while (offset < codepoints.size() && !(is_ascii_punctuation(codepoints[offset]) || unicode_engine.is_category(codepoints[offset], "PUNCTUATION"))) {
|
||||
codepoints_buffer.push_back(codepoints[offset]);
|
||||
offset++;
|
||||
}
|
||||
|
@ -239,4 +240,13 @@ private:
|
|||
|
||||
return results;
|
||||
}
|
||||
|
||||
// Returns true when `codepoint` is one of the 32 ASCII punctuation/symbol
// characters, i.e. it lies in one of these inclusive ranges:
//     33..47    ! " # $ % & ' ( ) * + , - . /
//     58..64    : ; < = > ? @
//     91..96    [ \ ] ^ _ `
//     123..126  { | } ~
// Every other value (controls, space, digits, letters, DEL and all
// non-ASCII code points) yields false.
//
// The accepted values form four contiguous ranges, so plain comparisons are
// used instead of a lazily-built hash set: no allocation, no static
// initialization guard and no mutable shared state on this per-codepoint path.
static bool is_ascii_punctuation(const uint32_t & codepoint) {
    return (codepoint >= 33  && codepoint <= 47)  ||
           (codepoint >= 58  && codepoint <= 64)  ||
           (codepoint >= 91  && codepoint <= 96)  ||
           (codepoint >= 123 && codepoint <= 126);
}
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue