llama3 custom regex split: fix \s

This commit is contained in:
jaime-m-p 2024-05-05 01:20:23 +02:00
parent 8fd849eb90
commit 67832e5554
2 changed files with 17 additions and 13 deletions

View file

@ -144,12 +144,13 @@ def test_custom_texts(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase):
] ]
more_tests = [ more_tests = [
'\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F} '\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F}
'¼-a', # unicode_ranges_digit, 0x00BC '¼-a', # unicode_ranges_digit, 0x00BC
'½-a', # unicode_ranges_digit, 0x00BD '½-a', # unicode_ranges_digit, 0x00BD
'¾-a', # unicode_ranges_digit, 0x00BE '¾-a', # unicode_ranges_digit, 0x00BE
'a b', # unicode_ranges_digit, 0x3007 'a b', # unicode_ranges_digit, 0x3007
'Ⅵ-a', # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms 'Ⅵ-a', # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
'\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM)
] ]
for text in tests+more_tests: for text in tests+more_tests:

View file

@ -281,6 +281,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
} }
} }
char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type); int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
// regex: <space>?\p{L}+ // regex: <space>?\p{L}+
if (cpt2_type == CODEPOINT_TYPE_LETTER) { if (cpt2_type == CODEPOINT_TYPE_LETTER) {
@ -301,17 +302,18 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
continue; continue;
} }
// regex: <space>?[^\s\p{L}\p{N}]+ // regex: <space>?[^\s\p{L}\p{N}]+
if (cpt2_type != CODEPOINT_TYPE_SEPARATOR && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
pos += (cpt == ' '); pos += (cpt == ' ');
while (cpt2_type != CODEPOINT_TYPE_SEPARATOR && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
cpt2_type = _get_cpt_type(++pos); cpt2_type = _get_cpt_type(++pos);
cpt2 = _get_cpt(pos);
} }
_add_token(pos); _add_token(pos);
continue; continue;
} }
size_t num_whitespaces = 0; size_t num_whitespaces = 0;
while (_get_cpt_type(pos+num_whitespaces) == CODEPOINT_TYPE_SEPARATOR) { while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
num_whitespaces++; num_whitespaces++;
} }
@ -424,13 +426,14 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
} }
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]* // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type); int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
if (cpt2_type != CODEPOINT_TYPE_SEPARATOR && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
pos += (cpt == ' '); pos += (cpt == ' ');
while (cpt2_type != CODEPOINT_TYPE_SEPARATOR && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
cpt2_type = _get_cpt_type(++pos); cpt2_type = _get_cpt_type(++pos);
cpt2 = _get_cpt(pos);
} }
char32_t cpt2 = _get_cpt(pos);
while (cpt2 == '\r' || cpt2 == '\n') { while (cpt2 == '\r' || cpt2 == '\n') {
cpt2 = _get_cpt(++pos); cpt2 = _get_cpt(++pos);
} }
@ -440,7 +443,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
size_t num_whitespaces = 0; size_t num_whitespaces = 0;
size_t last_end_r_or_n = 0; size_t last_end_r_or_n = 0;
while (_get_cpt_type(pos+num_whitespaces) == CODEPOINT_TYPE_SEPARATOR) { while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
char32_t cpt2 = _get_cpt(pos+num_whitespaces); char32_t cpt2 = _get_cpt(pos+num_whitespaces);
if (cpt2 == '\r' || cpt2 == '\n') { if (cpt2 == '\r' || cpt2 == '\n') {
last_end_r_or_n = pos + num_whitespaces + 1; last_end_r_or_n = pos + num_whitespaces + 1;