Fix and improve preprocessing
Fix unicode edge case combinations. Split by whitspace in the same pass.
This commit is contained in:
parent
938cb4941a
commit
117b091069
1 changed files with 33 additions and 50 deletions
83
llama.cpp
83
llama.cpp
|
@ -12711,72 +12711,55 @@ struct llm_tokenizer_wpm {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> preprocess(const std::string & text) {
|
std::vector<std::string> preprocess(const std::string & text) {
|
||||||
std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
|
const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
|
||||||
|
std::vector<std::string> words(1, "");
|
||||||
|
|
||||||
// strip accents, strip control, uniformize whitespace,
|
for (const char32_t cpt : cpts_nfd) {
|
||||||
// to lowercase, pad chinese characters, pad punctuation
|
const auto flags = unicode_cpt_flags(cpt);
|
||||||
std::string new_str = "";
|
|
||||||
for (uint32_t code : cpts_nfd) {
|
if (flags.is_whitespace) {
|
||||||
const codepoint_flags flags = unicode_cpt_flags(code);
|
if (words.back().size()) { // finish previous word if any
|
||||||
if (flags.is_accent_mark || flags.is_control) {
|
words.emplace_back();
|
||||||
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
code = unicode_tolower(code);
|
|
||||||
if (flags.is_separator || flags.is_whitespace) { //####FIXME: is_separator ?
|
assert (!flags.is_separator);
|
||||||
code = ' ';
|
if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
std::string s = unicode_cpt_to_utf8(code);
|
|
||||||
if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) {
|
const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
|
||||||
new_str += " ";
|
if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
|
||||||
new_str += s;
|
if (words.back().size()) { // finish previous word if any
|
||||||
new_str += " ";
|
words.emplace_back();
|
||||||
|
}
|
||||||
|
words.back() = s; // single char word
|
||||||
|
words.emplace_back(); // start a new word
|
||||||
} else {
|
} else {
|
||||||
new_str += s;
|
words.back() += s; // append char to word
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// split by whitespace
|
if (!words.back().size()) {
|
||||||
uint64_t l = 0;
|
words.pop_back();
|
||||||
uint64_t r = 0;
|
|
||||||
std::vector<std::string> words;
|
|
||||||
while (r < new_str.size()) {
|
|
||||||
// if is whitespace
|
|
||||||
if (isspace(new_str[r], std::locale::classic())) {
|
|
||||||
if (r > l) words.push_back(new_str.substr(l, (r - l)));
|
|
||||||
l = r + 1;
|
|
||||||
r = l;
|
|
||||||
} else {
|
|
||||||
r += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (r > l) {
|
|
||||||
words.push_back(new_str.substr(l, (r - l)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return words;
|
return words;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_ascii_punct(uint32_t code) {
|
static bool is_chinese_char(uint32_t cpt) {
|
||||||
if (code > 0xFF) {
|
return
|
||||||
return false;
|
(cpt >= 0x04E00 && cpt <= 0x09FFF) ||
|
||||||
}
|
(cpt >= 0x03400 && cpt <= 0x04DBF) ||
|
||||||
auto c = char(static_cast<unsigned char>(code));
|
|
||||||
return ispunct(c, std::locale::classic());
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_chinese_char(uint32_t cpt) {
|
|
||||||
if ((cpt >= 0x4E00 && cpt <= 0x9FFF) ||
|
|
||||||
(cpt >= 0x3400 && cpt <= 0x4DBF) ||
|
|
||||||
(cpt >= 0x20000 && cpt <= 0x2A6DF) ||
|
(cpt >= 0x20000 && cpt <= 0x2A6DF) ||
|
||||||
(cpt >= 0x2A700 && cpt <= 0x2B73F) ||
|
(cpt >= 0x2A700 && cpt <= 0x2B73F) ||
|
||||||
(cpt >= 0x2B740 && cpt <= 0x2B81F) ||
|
(cpt >= 0x2B740 && cpt <= 0x2B81F) ||
|
||||||
(cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
|
(cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
|
||||||
(cpt >= 0xF900 && cpt <= 0xFAFF) ||
|
(cpt >= 0x0F900 && cpt <= 0x0FAFF) ||
|
||||||
(cpt >= 0x2F800 && cpt <= 0x2FA1F) ||
|
(cpt >= 0x2F800 && cpt <= 0x2FA1F);
|
||||||
(cpt >= 0x3000 && cpt <= 0x303F) ||
|
//(cpt >= 0x3000 && cpt <= 0x303F) ||
|
||||||
(cpt >= 0xFF00 && cpt <= 0xFFEF)) {
|
//(cpt >= 0xFF00 && cpt <= 0xFFEF);
|
||||||
return true; // NOLINT
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_vocab & vocab;
|
const llama_vocab & vocab;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue