diff --git a/llama.cpp b/llama.cpp index fd8eaa180..8cf12433f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -119,6 +119,15 @@ static void llama_log_callback_default(llama_log_level level, const char * text, // helpers // +void replace_all(std::string & s, const std::string & search, const std::string & replace) { + for (size_t pos = 0; ; pos += replace.length()) { + pos = s.find(search, pos); + if (pos == std::string::npos) break; + s.erase(pos, search.length()); + s.insert(pos, replace); + } +} + static void zeros(std::ofstream & file, size_t n) { char zero = 0; for (size_t i = 0; i < n; ++i) { @@ -2264,11 +2273,8 @@ static std::string llama_escape_whitespace(const std::string& text) { return result; } -static std::string llama_unescape_whitespace(const std::string& word) { - if (word.length() >= 3 && word.substr(0, 3) == "\xe2\x96\x81") { - return std::string(" ") + word.substr(3); - } - return word; +static void llama_unescape_whitespace(std::string & word) { + replace_all(word, "\xe2\x96\x81", " "); } static size_t utf8_len(char src) { @@ -4902,7 +4908,7 @@ int llama_token_to_str_with_model(const struct llama_model * model, llama_token if (llama_is_normal_token(model->vocab, token)) { std::string result = model->vocab.id_to_token[token].text; if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) { - result = llama_unescape_whitespace(result); + llama_unescape_whitespace(result); } if (length < (int) result.length()) { return -result.length();