diff --git a/Makefile b/Makefile index 0ca968b30..079527ca5 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ BUILD_TARGETS = \ TEST_TARGETS = \ tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt \ tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama \ - tests/test-tokenizer-0-falcon tests/test-tokenizer-0-deepseek-coder \ + tests/test-tokenizer-0-falcon tests/test-tokenizer-0-deepseek-coder tests/test-tokenizer-0-deepseek_llm \ tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe # Code coverage output files @@ -72,6 +72,8 @@ test: $(TEST_TARGETS) ./$$test_target $(CURDIR)/models/ggml-vocab-falcon.gguf; \ elif [ "$$test_target" = "tests/test-tokenizer-0-deepseek-coder" ]; then \ ./$$test_target $(CURDIR)/models/ggml-vocab-deepseek-coder.gguf; \ + elif [ "$$test_target" = "tests/test-tokenizer-0-deepseek_llm" ]; then \ + ./$$test_target $(CURDIR)/models/ggml-vocab-deepseek-llm.gguf $(CURDIR)/README.md; \ elif [ "$$test_target" = "tests/test-tokenizer-1-llama" ]; then \ continue; \ elif [ "$$test_target" = "tests/test-tokenizer-1-bpe" ]; then \ @@ -734,6 +736,9 @@ tests/test-tokenizer-0-llama: tests/test-tokenizer-0-llama.cpp ggml.o llama.o $( tests/test-tokenizer-0-deepseek-coder: tests/test-tokenizer-0-deepseek-coder.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) +tests/test-tokenizer-0-deepseek_llm: tests/test-tokenizer-0-deepseek_llm.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) + tests/test-tokenizer-1-bpe: tests/test-tokenizer-1-bpe.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index a26156b39..6b61b73bb 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -195,6 +195,8 @@ class Model: return PersimmonModel if model_name_lower 
== "deepseekcoder": return DeepseekCoderModel + if model_name_lower == "deepseekllm": + return DeepseekLLMModel return Model def _is_model_safetensors(self) -> bool: @@ -865,6 +867,9 @@ class DeepseekCoderModel(Model): def set_vocab(self): self._set_vocab_gpt2("deepseek_coder") +class DeepseekLLMModel(DeepseekCoderModel): + def set_vocab(self): + self._set_vocab_gpt2("deepseek_llm") class StableLMModel(Model): def set_gguf_parameters(self): diff --git a/llama.cpp b/llama.cpp index 71fb44a05..005f60d20 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2385,14 +2385,19 @@ static void llm_load_vocab( vocab.special_unk_id = 0; vocab.special_sep_id = -1; vocab.special_pad_id = -1; - } else if (tokenizer_name == "gpt2" || tokenizer_name == "deepseek_coder") { - if(tokenizer_name == "gpt2") { + } else { + if (tokenizer_name == "gpt2") { vocab.type = LLAMA_VOCAB_TYPE_BPE; - } - else if (tokenizer_name == "deepseek_coder") { + } else if (tokenizer_name == "deepseek_coder") { vocab.type = LLAMA_VOCAB_TYPE_DEEPSEEKCODER; + } else if (tokenizer_name == "deepseek_llm") { + vocab.type = LLAMA_VOCAB_TYPE_DEEPSEEKLLM; + } else { + LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); + LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); + vocab.type = LLAMA_VOCAB_TYPE_SPM; + return; } - // read bpe merges and populate bpe ranks const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); if (merges_keyidx == -1) { @@ -2424,11 +2429,6 @@ static void llm_load_vocab( vocab.special_unk_id = -1; vocab.special_sep_id = -1; vocab.special_pad_id = -1; - } else { - LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); - LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); - - vocab.type = LLAMA_VOCAB_TYPE_SPM; } } @@ -2605,7 +2605,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { // hparams LLAMA_LOG_INFO("%s: format = %s\n", __func__, 
llama_file_version_name(ml.fver)); LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); - LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : (vocab.type == LLAMA_VOCAB_TYPE_BPE ? "BPE" : "DEEPSEEKCODER")); // TODO: fix + LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size()); LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); @@ -5956,10 +5956,20 @@ struct llm_tokenizer_bpe { void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; std::vector word_collection; - if(vocab.type == LLAMA_VOCAB_TYPE_BPE) + switch (vocab.type) + { + case LLAMA_VOCAB_TYPE_BPE: word_collection = bpe_gpt2_preprocess(text); - else if(vocab.type==LLAMA_VOCAB_TYPE_DEEPSEEKCODER) + break; + case LLAMA_VOCAB_TYPE_DEEPSEEKCODER: word_collection = bpe_deepseek_coder_preprocess(text); + break; + case LLAMA_VOCAB_TYPE_DEEPSEEKLLM: + word_collection = bpe_deepseek_llm_preprocess(text); + break; + default: + break; + } symbols_final.clear(); @@ -6159,6 +6169,10 @@ private: return regex_bpe_preprocess(text, deepseek_coder_regex); } + std::vector bpe_deepseek_llm_preprocess(const std::string & text) { + return regex_bpe_preprocess(text, deepseek_llm_regex); + } + const llama_vocab & vocab; std::vector symbols; @@ -6350,6 +6364,7 @@ static std::vector llama_tokenize_internal(const llama_vocab & } } break; case LLAMA_VOCAB_TYPE_DEEPSEEKCODER: + case LLAMA_VOCAB_TYPE_DEEPSEEKLLM: case LLAMA_VOCAB_TYPE_BPE: { for (const auto & fragment: fragment_buffer) @@ -9467,6 +9482,7 @@ int llama_token_to_piece(const struct llama_model * model, llama_token token, ch break; } case LLAMA_VOCAB_TYPE_DEEPSEEKCODER: + case LLAMA_VOCAB_TYPE_DEEPSEEKLLM: case LLAMA_VOCAB_TYPE_BPE: { if 
(llama_is_normal_token(model->vocab, token)) { std::string result = model->vocab.id_to_token[token].text; diff --git a/llama.h b/llama.h index 45baa9333..9b7387e91 100644 --- a/llama.h +++ b/llama.h @@ -70,6 +70,7 @@ extern "C" { LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding LLAMA_VOCAB_TYPE_DEEPSEEKCODER = 2, // deepseek coder + LLAMA_VOCAB_TYPE_DEEPSEEKLLM = 3, // deepseek llm }; enum llama_token_type { diff --git a/models/ggml-vocab-deepseek-llm.gguf b/models/ggml-vocab-deepseek-llm.gguf new file mode 100644 index 000000000..3deb18212 Binary files /dev/null and b/models/ggml-vocab-deepseek-llm.gguf differ diff --git a/tests/test-tokenizer-0-deepseek_llm.cpp b/tests/test-tokenizer-0-deepseek_llm.cpp new file mode 100644 index 000000000..e60731071 --- /dev/null +++ b/tests/test-tokenizer-0-deepseek_llm.cpp @@ -0,0 +1,188 @@ +#include "llama.h" +#include "common.h" +#include "console.h" + +#include +#include +#include +#include +#include + +// generate using test-tokenizer-0-deepseek_llm.py +static const std::map> & k_tests() { + static std::map> _k_tests = { + { "" , { }, }, + { " " , { 207, }, }, + { " " , { 243, }, }, + { " " , { 300, }, }, + { "\t" , { 184, }, }, + { "\n" , { 185, }, }, + { "\t\n" , { 184, 185, }, }, + { "Hello world" , { 17464, 1843, }, }, + { " Hello world" , { 37727, 1843, }, }, + { "Hello World" , { 17464, 5427, }, }, + { " Hello World" , { 37727, 5427, }, }, + { " Hello World!" , { 37727, 5427, 0, }, }, + { "Hello, world!" , { 17464, 11, 1843, 0, }, }, + { " Hello, world!" 
, { 37727, 11, 1843, 0, }, }, + { " this is 🦙.cpp" , { 437, 317, 12356, 99, 234, 13, 14743, }, }, + { "w048 7tuijk dsdfhu" , { 86, 15, 19, 23, 207, 22, 83, 3970, 27519, 26016, 3944, 14025, }, }, + { "нещо на Български" , { 1603, 6476, 620, 91754, }, }, + { "កាន់តែពិសេសអាចខលចេញ" , { 71374, 209, 71374, 114, 71374, 228, 155, 240, 220, 71374, 224, 155, 240, 211, 71374, 231, 71374, 115, 71374, 240, 155, 240, 210, 71374, 240, 71374, 95, 71374, 114, 71374, 214, 71374, 210, 71374, 236, 71374, 214, 155, 240, 210, 71374, 218, }, }, + { "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", { 10044, 95300, 334, 8754, 8, 33701, 114, 350, 222, 10044, 221, 104, 46713, 334, 34732, 996, 24250, 262, 80923, 8, 207, 37103, 214, 334, 5956, 89213, 344, 643, 895, 1377, 10728, 8, }, }, + { "Hello" , { 17464, }, }, + { " Hello" , { 37727, }, }, + { " Hello" , { 207, 37727, }, }, + { " Hello" , { 243, 37727, }, }, + { " Hello" , { 300, 37727, }, }, + { " Hello\n Hello" , { 300, 37727, 185, 300, 37727, }, }, + { "\n =" , { 185, 403, }, }, + { "' era" , { 6, 2906, }, }, + { "Hello, y'all! 
How are you 😁 ?我想在apple工作1314151天~", { 17464, 11, 320, 6, 436, 0, 1724, 418, 340, 33701, 210, 3025, 19017, 612, 9407, 2681, 16, 18, 16, 19, 16, 20, 16, 1398, 68940, 239, }, }, + + }; + + return _k_tests; +} + +int main(int argc, char **argv) { + if (argc < 2) { + fprintf(stderr, "Usage: %s vocab-file [text-file]\n", argv[0]); + return 1; + } + + const std::string fname = argv[1]; + + std::string fname_text; + if (argc > 2) { + fname_text = argv[2]; + } + + fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str()); + + llama_model * model; + llama_context * ctx; + + llama_backend_init(false); + + // load the vocab + { + auto mparams = llama_model_default_params(); + + mparams.vocab_only = true; + + model = llama_load_model_from_file(fname.c_str(), mparams); + + if (model == NULL) { + fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); + return 1; + } + + auto cparams = llama_context_default_params(); + + ctx = llama_new_context_with_model(model, cparams); + + if (ctx == NULL) { + fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); + llama_free_model(model); + return 1; + } + } + + if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_DEEPSEEKLLM) { + fprintf(stderr, "%s : error: vocab type is not DEEPSEEKLLM\n", __func__); + llama_free_model(model); + llama_free(ctx); + return 2; + } + +#ifdef _WIN32 + // We need this for unicode console support + console::init(false, false); + atexit([]() { console::cleanup(); }); +#endif + + bool success = true; + + for (const auto & test_kv : k_tests()) { + const std::vector res = llama_tokenize(ctx, test_kv.first, false); + + printf("\n"); + printf("src: '%s'\n", test_kv.first.c_str()); + printf("res: '%s'\n", llama_detokenize_bpe(ctx, res).c_str()); + printf("tok: "); + for (const auto & tok : res) { + printf("%d ", tok); + } + printf("\n"); + + bool correct = res.size() == test_kv.second.size(); + for (int i = 0; i < (int) res.size() && correct; ++i) { 
+ if (test_kv.second[i] != res[i]) { + correct = false; + } + } + + if (!correct) { + fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str()); + fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__, + llama_detokenize_bpe(ctx, res).c_str(), + llama_detokenize_bpe(ctx, test_kv.second).c_str()); + fprintf(stderr, "%s : expected tokens: ", __func__); + for (const auto & t : test_kv.second) { + fprintf(stderr, "%6d, ", t); + } + fprintf(stderr, "\n"); + fprintf(stderr, "%s : got tokens: ", __func__); + for (const auto & t : res) { + fprintf(stderr, "%6d, ", t); + } + fprintf(stderr, "\n"); + + success = false; + } + } + + if (!fname_text.empty()) { + fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str()); + + std::string text; + { + std::ifstream ifs(fname_text); + if (!ifs) { + fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_text.c_str()); + return 1; + } + text = std::string(std::istreambuf_iterator(ifs), std::istreambuf_iterator()); + } + + fprintf(stderr, "%s : text size: %zu\n", __func__, text.size()); + + const std::vector res = llama_tokenize(ctx, text, false); + + fprintf(stderr, "%s : tokens: %zu\n", __func__, res.size()); + + { + const std::string fname_out = fname_text + ".tokcpp"; + + std::ofstream ofs(fname_out); + if (!ofs) { + fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_out.c_str()); + return 1; + } + + for (const auto & tok : res) { + ofs << tok << " '" << llama_detokenize_bpe(ctx, std::vector{tok}) << "'" << std::endl; + } + } + + fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str()); + } + + llama_free_model(model); + llama_free(ctx); + + llama_backend_free(); + + return success ? 
0 : 3; +} diff --git a/tests/test-tokenizer-0-deepseek_llm.py b/tests/test-tokenizer-0-deepseek_llm.py new file mode 100644 index 000000000..b99840e1b --- /dev/null +++ b/tests/test-tokenizer-0-deepseek_llm.py @@ -0,0 +1,83 @@ +# tests with BPE tokenizer + +import argparse + +from transformers import AutoTokenizer + +parser = argparse.ArgumentParser() +parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file") +parser.add_argument("--fname-tok", help="path to a text file to tokenize") +args = parser.parse_args() + +dir_tokenizer = args.dir_tokenizer + +tokenizer = AutoTokenizer.from_pretrained(dir_tokenizer) + +tests = [ + "", + " ", + " ", + " ", + "\t", + "\n", + "\t\n", + "Hello world", + " Hello world", + "Hello World", + " Hello World", + " Hello World!", + "Hello, world!", + " Hello, world!", + " this is 🦙.cpp", + "w048 7tuijk dsdfhu", + "нещо на Български", + "កាន់តែពិសេសអាចខលចេញ", + "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", + "Hello", + " Hello", + " Hello", + " Hello", + " Hello", + " Hello\n Hello", + "\n =", + "' era", + "Hello, y'all! 
How are you 😁 ?我想在apple工作1314151天~", +] + +for text in tests: + print('text: ', text) + print(tokenizer.encode(text)) + print(tokenizer.decode(tokenizer.encode(text))) + +print("\n\ntests for C++:\n") +for text in tests: + res = tokenizer.encode(text) + + k = text.replace('\n', '\\n') + k = k.replace('\t', '\\t') + k = '"' + k + '"' + print("{ %-24s, { " % k, end='') + for x in res: + print("%7d," % x, end='') + print(" }, },") + +print(tokenizer.encode('hello')) +print(tokenizer.encode('world')) +print(tokenizer.encode(' world')) +print(tokenizer.encode('hello world')) + +fname_tok = args.fname_tok +if fname_tok: + print('tokenizing file: ', fname_tok) + fname_out = fname_tok + '.tok' + with open(fname_tok, 'r', encoding='utf-8') as f: + lines = f.readlines() + s = ''.join(lines) + res = tokenizer.encode(s) + # write to file + with open(fname_out, 'w', encoding='utf-8') as f: + for x in res: + f.write(str(x) + ' \'' + tokenizer.decode(x) + '\'\n') + print('len(res): ', len(res)) + print('len(lines): ', len(lines)) + print('results written to: ', fname_out) diff --git a/unicode.h b/unicode.h index dc5188970..44cc7439a 100644 --- a/unicode.h +++ b/unicode.h @@ -486,6 +486,15 @@ static const std::vector deepseek_coder_regex = { 
L"[\U00000030-\U00000039\U000000B2-\U000000B3\U000000B9-\U000000B9\U00000660-\U00000669\U000006F0-\U000006F9\U000007C0-\U000007C9\U00000966-\U0000096F\U000009E6-\U000009EF\U00000A66-\U00000A6F\U00000AE6-\U00000AEF\U00000B66-\U00000B6F\U00000BE6-\U00000BEF\U00000C66-\U00000C6F\U00000CE6-\U00000CEF\U00000D66-\U00000D6F\U00000DE6-\U00000DEF\U00000E50-\U00000E59\U00000ED0-\U00000ED9\U00000F20-\U00000F29\U00001040-\U00001049\U00001090-\U00001099\U00001369-\U00001371\U000017E0-\U000017E9\U00001810-\U00001819\U00001946-\U0000194F\U000019D0-\U000019DA\U00001A80-\U00001A89\U00001A90-\U00001A99\U00001B50-\U00001B59\U00001BB0-\U00001BB9\U00001C40-\U00001C49\U00001C50-\U00001C59\U00002070-\U00002070\U00002074-\U00002079\U00002080-\U00002089\U00002460-\U00002468\U00002474-\U0000247C\U00002488-\U00002490\U000024EA-\U000024EA\U000024F5-\U000024FD\U000024FF-\U000024FF\U00002776-\U0000277E\U00002780-\U00002788\U0000278A-\U00002792\U0000A620-\U0000A629\U0000A8D0-\U0000A8D9\U0000A900-\U0000A909\U0000A9D0-\U0000A9D9\U0000A9F0-\U0000A9F9\U0000AA50-\U0000AA59\U0000ABF0-\U0000ABF9\U0000FF10-\U0000FF19\U000104A0-\U000104A9\U00010A40-\U00010A43\U00010D30-\U00010D39\U00010E60-\U00010E68\U00011052-\U0001105A\U00011066-\U0001106F\U000110F0-\U000110F9\U00011136-\U0001113F\U000111D0-\U000111D9\U000112F0-\U000112F9\U00011450-\U00011459\U000114D0-\U000114D9\U00011650-\U00011659\U000116C0-\U000116C9\U00011730-\U00011739\U000118E0-\U000118E9\U00011950-\U00011959\U00011C50-\U00011C59\U00011D50-\U00011D59\U00011DA0-\U00011DA9\U00016A60-\U00016A69\U00016B50-\U00016B59\U0001D7CE-\U0001D7FF\U0001E140-\U0001E149\U0001E2F0-\U0001E2F9\U0001E950-\U0001E959\U0001F100-\U0001F10A\U0001FBF0-\U0001FBF9]" }; +static const std::vector deepseek_llm_regex = { + L"[\r\n]", + L"\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+", + 
L"\\s?[\u0021-\u002f\u003a-\u007e\uff01-\uff0f\uff1a-\uff5e\u2018-\u201f\u3000-\u3002]+", + L"\\s+$", + L"[\u4e00-\u9fa5\u0800-\u4e00\uac00-\ud7ff]+", + L"[\U00000030-\U00000039\U000000B2-\U000000B3\U000000B9-\U000000B9\U00000660-\U00000669\U000006F0-\U000006F9\U000007C0-\U000007C9\U00000966-\U0000096F\U000009E6-\U000009EF\U00000A66-\U00000A6F\U00000AE6-\U00000AEF\U00000B66-\U00000B6F\U00000BE6-\U00000BEF\U00000C66-\U00000C6F\U00000CE6-\U00000CEF\U00000D66-\U00000D6F\U00000DE6-\U00000DEF\U00000E50-\U00000E59\U00000ED0-\U00000ED9\U00000F20-\U00000F29\U00001040-\U00001049\U00001090-\U00001099\U00001369-\U00001371\U000017E0-\U000017E9\U00001810-\U00001819\U00001946-\U0000194F\U000019D0-\U000019DA\U00001A80-\U00001A89\U00001A90-\U00001A99\U00001B50-\U00001B59\U00001BB0-\U00001BB9\U00001C40-\U00001C49\U00001C50-\U00001C59\U00002070-\U00002070\U00002074-\U00002079\U00002080-\U00002089\U00002460-\U00002468\U00002474-\U0000247C\U00002488-\U00002490\U000024EA-\U000024EA\U000024F5-\U000024FD\U000024FF-\U000024FF\U00002776-\U0000277E\U00002780-\U00002788\U0000278A-\U00002792\U0000A620-\U0000A629\U0000A8D0-\U0000A8D9\U0000A900-\U0000A909\U0000A9D0-\U0000A9D9\U0000A9F0-\U0000A9F9\U0000AA50-\U0000AA59\U0000ABF0-\U0000ABF9\U0000FF10-\U0000FF19\U000104A0-\U000104A9\U00010A40-\U00010A43\U00010D30-\U00010D39\U00010E60-\U00010E68\U00011052-\U0001105A\U00011066-\U0001106F\U000110F0-\U000110F9\U00011136-\U0001113F\U000111D0-\U000111D9\U000112F0-\U000112F9\U00011450-\U00011459\U000114D0-\U000114D9\U00011650-\U00011659\U000116C0-\U000116C9\U00011730-\U00011739\U000118E0-\U000118E9\U00011950-\U00011959\U00011C50-\U00011C59\U00011D50-\U00011D59\U00011DA0-\U00011DA9\U00016A60-\U00016A69\U00016B50-\U00016B59\U0001D7CE-\U0001D7FF\U0001E140-\U0001E149\U0001E2F0-\U0001E2F9\U0001E950-\U0001E959\U0001F100-\U0001F10A\U0001FBF0-\U0001FBF9]" + }; + inline std::wstring from_utf8(const std::string& s) { std::wstring_convert> conv;