optimized performance

This commit is contained in:
Bingxuan Wang 2023-11-21 20:35:38 +08:00
parent 4494a9f655
commit 3a7f0c4cf3
2 changed files with 55 additions and 68 deletions

View file

@ -3271,10 +3271,14 @@ static bool llama_model_load(const std::string & fname, llama_model & model, con
model.hparams.vocab_only = params.vocab_only; model.hparams.vocab_only = params.vocab_only;
llm_load_arch (ml, model); llm_load_arch (ml, model);
std::cout<<"here 1!"<<std::endl;
llm_load_hparams(ml, model); llm_load_hparams(ml, model);
std::cout<<"here 2!"<<std::endl;
llm_load_vocab (ml, model); llm_load_vocab (ml, model);
std::cout<<"here 3!"<<std::endl;
llm_load_print_meta(ml, model); llm_load_print_meta(ml, model);
std::cout<<"here 4!"<<std::endl;
if (model.hparams.n_vocab != model.vocab.id_to_token.size()) { if (model.hparams.n_vocab != model.vocab.id_to_token.size()) {
throw std::runtime_error("vocab size mismatch"); throw std::runtime_error("vocab size mismatch");
@ -6001,79 +6005,60 @@ private:
return bpe_encoded_words; return bpe_encoded_words;
} }
std::vector<std::string> regex_preprocess(const std::vector<std::string> &input, const std::string & regex_expr) { std::vector<size_t> regex_preprocess(const std::wstring & text, const std::vector<size_t> & offsets, const std::wstring& regex_expr) {
std::regex expr(regex_expr); std::wregex expr(regex_expr);
std::vector<std::string> bpe_words; std::vector<size_t> bpe_words; // store the offset of each word
// std::wsmatch m; bpe_words.reserve(offsets.size()); // Reserve memory for the approximate size
// // use regex match to get where to split the test string size_t start = 0;
for(auto& text:input) { for ( auto & offset : offsets) {
std::cregex_iterator it(text.data(), text.data() + text.size(), expr); std::wcregex_iterator it(text.data() + start, text.data() + start + offset, expr);
std::cregex_iterator end; std::wcregex_iterator end;
// Print the matches size_t start_idx = 0;
unsigned int start_idx = 0;
while (it != end) { while (it != end) {
std::cmatch match = *it; std::wcmatch match = *it;
std::string match_str = match.str(); if (match.position() > start_idx) {
if(match.position()>start_idx) { bpe_words.emplace_back(match.position() - start_idx);
bpe_words.emplace_back(text.substr(start_idx, match.position()-start_idx));
} }
bpe_words.emplace_back(match_str); bpe_words.emplace_back(match.length());
start_idx = match.position() + match.length(); start_idx = match.position() + match.length();
++it; ++it;
} }
if(start_idx < text.size()) { if (start_idx < offset) {
bpe_words.emplace_back(text.substr(start_idx, text.size()-start_idx)); bpe_words.emplace_back(offset - start_idx);
} }
start += offset;
} }
return bpe_words; return bpe_words;
} }
std::vector<std::string> bpe_gpt2_preprocess(const std::string & text) {
std::vector<std::string> bpe_words = {text}; std::vector<std::string> regex_bpe_preprocess(const std::string & text, const std::vector<std::wstring> & regex_exprs) {
std::wstring wtext = from_utf8(text);
for(auto & regex_expr : gpt2_regex) { std::vector<size_t> bpe_offsets = {wtext.size()};
bpe_words = regex_preprocess(bpe_words, regex_expr);
for(auto & regex_expr : regex_exprs) {
bpe_offsets = regex_preprocess(wtext, bpe_offsets, regex_expr);
} }
std::vector<std::string> bpe_encoded_words = byte_encoding_process(bpe_words); std::vector<std::string> bpe_words;
bpe_words.reserve(bpe_offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for(size_t & offset : bpe_offsets){
bpe_words.emplace_back(to_utf8(std::wstring(wtext, start, offset)));
start += offset;
}
return bpe_encoded_words; return byte_encoding_process(bpe_words);
}
std::vector<std::string> bpe_gpt2_preprocess(const std::string & text) {
return regex_bpe_preprocess(text, gpt2_regex);
} }
std::vector<std::string> bpe_deepseek_coder_preprocess(const std::string & text) { std::vector<std::string> bpe_deepseek_coder_preprocess(const std::string & text) {
return regex_bpe_preprocess(text, deepseek_coder_regex);
std::vector<std::string> bpe_words;
std::wstring wtext = from_utf8(text);
// extract all CJK characters
std::wregex expr(L"[\u4e00-\u9fa5\u0800-\u4e00\uac00-\ud7ff]+");
std::wcregex_iterator it(wtext.data(), wtext.data() + wtext.size(), expr);
std::wcregex_iterator end;
unsigned int start_idx = 0;
while (it != end) {
std::wcmatch match = *it;
std::wstring match_str = match.str();
if(match.position()>start_idx) {
bpe_words.emplace_back(to_utf8(wtext.substr(start_idx, match.position()-start_idx)));
}
bpe_words.emplace_back(to_utf8(match_str));
start_idx = match.position() + match.length();
++it;
}
if(start_idx < wtext.size()) {
bpe_words.emplace_back(to_utf8(wtext.substr(start_idx, wtext.size()-start_idx)));
}
for(auto & regex_expr : deepseek_coder_regex) {
bpe_words = regex_preprocess(bpe_words, regex_expr);
}
std::vector<std::string> bpe_encoded_words = byte_encoding_process(bpe_words);
return bpe_encoded_words;
} }
const llama_vocab & vocab; const llama_vocab & vocab;

File diff suppressed because one or more lines are too long