From 67832e55549245564cad6314d3c063ecd2ad4bd0 Mon Sep 17 00:00:00 2001 From: jaime-m-p <> Date: Sun, 5 May 2024 01:20:23 +0200 Subject: [PATCH] llama3 custom regex split: fix \s --- tests/test-tokenizer-random-bpe.py | 13 +++++++------ unicode.cpp | 17 ++++++++++------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/tests/test-tokenizer-random-bpe.py b/tests/test-tokenizer-random-bpe.py index 379154033..3d7d39fce 100644 --- a/tests/test-tokenizer-random-bpe.py +++ b/tests/test-tokenizer-random-bpe.py @@ -144,12 +144,13 @@ def test_custom_texts(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase): ] more_tests = [ - '\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F} - '¼-a', # unicode_ranges_digit, 0x00BC - '½-a', # unicode_ranges_digit, 0x00BD - '¾-a', # unicode_ranges_digit, 0x00BE - 'a 〇b', # unicode_ranges_digit, 0x3007 - 'Ⅵ-a', # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms + '\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F} + '¼-a', # unicode_ranges_digit, 0x00BC + '½-a', # unicode_ranges_digit, 0x00BD + '¾-a', # unicode_ranges_digit, 0x00BE + 'a 〇b', # unicode_ranges_digit, 0x3007 + 'Ⅵ-a', # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms + '\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM) ] for text in tests+more_tests: diff --git a/unicode.cpp b/unicode.cpp index 31502957e..1c9c34828 100644 --- a/unicode.cpp +++ b/unicode.cpp @@ -281,6 +281,7 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string & t } } + char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt); int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type); // regex: ?\p{L}+ if (cpt2_type == CODEPOINT_TYPE_LETTER) { @@ -301,17 +302,18 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string & t continue; } // regex: ?[^\s\p{L}\p{N}]+ - if (cpt2_type != CODEPOINT_TYPE_SEPARATOR && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { + if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { pos += (cpt == ' '); - while (cpt2_type != CODEPOINT_TYPE_SEPARATOR && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { + while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { cpt2_type = _get_cpt_type(++pos); + cpt2 = _get_cpt(pos); } _add_token(pos); continue; } size_t num_whitespaces = 0; - while (_get_cpt_type(pos+num_whitespaces) == CODEPOINT_TYPE_SEPARATOR) { + while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) { num_whitespaces++; } @@ -424,13 +426,14 @@ static std::vector unicode_regex_split_custom_llama3(const std::string & } // regex: ?[^\s\p{L}\p{N}]+[\r\n]* + char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt); int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type); - if (cpt2_type != CODEPOINT_TYPE_SEPARATOR && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { + if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { pos += (cpt == ' '); - while (cpt2_type != CODEPOINT_TYPE_SEPARATOR && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { + while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { cpt2_type = _get_cpt_type(++pos); + cpt2 = _get_cpt(pos); } - char32_t cpt2 = _get_cpt(pos); while (cpt2 == '\r' || cpt2 == '\n') { cpt2 = _get_cpt(++pos); } @@ -440,7 +443,7 @@ static std::vector unicode_regex_split_custom_llama3(const std::string & size_t num_whitespaces = 0; size_t last_end_r_or_n = 0; - while (_get_cpt_type(pos+num_whitespaces) == CODEPOINT_TYPE_SEPARATOR) { + while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) { char32_t cpt2 = _get_cpt(pos+num_whitespaces); if (cpt2 == '\r' || cpt2 == '\n') { last_end_r_or_n = pos + num_whitespaces + 1;