This commit is contained in:
Georgi Gerganov 2024-04-27 00:28:36 +03:00
parent a774d7084e
commit c160818ec0
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 45 additions and 9 deletions

View file

@ -1,4 +1,4 @@
#include "unicode.h"
#include "unicode.h"
#include "unicode-data.h"
#include <cassert>
@ -225,7 +225,7 @@ static std::vector<std::string> unicode_byte_encoding_process(const std::vector<
}
static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wtext, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets; // stroe the offset of each word
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
@ -364,7 +364,7 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wt
static std::vector<size_t> unicode_regex_preprocess(const std::wstring & text, const std::vector<size_t> & offsets, const std::wstring & regex_expr) {
std::wregex expr(regex_expr);
std::vector<size_t> bpe_offsets; // stroe the offset of each word
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for (auto offset : offsets) {
@ -391,6 +391,35 @@ static std::vector<size_t> unicode_regex_preprocess(const std::wstring & text, c
return bpe_offsets;
}
static std::vector<size_t> unicode_regex_preprocess_fallback(const std::string & text, const std::vector<size_t> & offsets, const std::string & regex_expr) {
std::regex expr(regex_expr);
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for (auto offset : offsets) {
std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
std::cregex_iterator end;
int64_t start_idx = 0;
while (it != end) {
std::cmatch match = *it;
if (match.position() > start_idx) {
bpe_offsets.emplace_back(match.position() - start_idx);
}
bpe_offsets.emplace_back(match.length());
start_idx = match.position() + match.length();
++it;
}
if (start_idx < (int64_t) offset) {
bpe_offsets.emplace_back(offset - start_idx);
}
start += offset;
}
return bpe_offsets;
}
static bool unicode_regex_equivalent_wregex_exists(const std::string & regex) {
return unicode_regex_equivalent_wregex.find(regex) != unicode_regex_equivalent_wregex.end();
}
@ -503,7 +532,13 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
const std::wstring & wregex_expr = unicode_regex_equivalent_wregex.at(regex_expr);
bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, wregex_expr);
} else {
throw std::runtime_error("Unicode regex is not found");
try {
bpe_offsets = unicode_regex_preprocess_fallback(text, bpe_offsets, regex_expr);
} catch (std::regex_error & e) {
fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
fprintf(stderr, "Regex error: %s\n", e.what());
throw std::runtime_error("Failed to process regex");
}
}
}