tests : fix wstring_convert

This commit is contained in:
Georgi Gerganov 2023-08-14 20:50:15 +03:00
parent aa0551a504
commit 01080a5a51
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 14 additions and 9 deletions

View file

@ -63,6 +63,7 @@ static void llama_log_callback_default(llama_log_level level, const char * text,
#define LLAMA_LOG_WARN(...) llama_log_internal(LLAMA_LOG_LEVEL_WARN , __VA_ARGS__) #define LLAMA_LOG_WARN(...) llama_log_internal(LLAMA_LOG_LEVEL_WARN , __VA_ARGS__)
#define LLAMA_LOG_ERROR(...) llama_log_internal(LLAMA_LOG_LEVEL_ERROR, __VA_ARGS__) #define LLAMA_LOG_ERROR(...) llama_log_internal(LLAMA_LOG_LEVEL_ERROR, __VA_ARGS__)
#if !defined(GGML_USE_CUBLAS) && !defined(GGML_USE_METAL) #if !defined(GGML_USE_CUBLAS) && !defined(GGML_USE_METAL)
#include "ggml-alloc.h" #include "ggml-alloc.h"
#define LLAMA_USE_ALLOCATOR #define LLAMA_USE_ALLOCATOR
@ -1988,7 +1989,7 @@ static bool llama_is_eos_token(const llama_vocab& vocab, llama_token token) {
return false; return false;
} }
static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) { static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token token) {
UNUSED(vocab); UNUSED(vocab);
UNUSED(token); UNUSED(token);
// TODO: improve? // TODO: improve?
@ -4400,24 +4401,24 @@ int llama_token_to_str_with_model(const struct llama_model * model, llama_token
if (0 <= token && token < llama_n_vocab_from_model(model)) { if (0 <= token && token < llama_n_vocab_from_model(model)) {
if (llama_is_normal_token(model->vocab, token)) { if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].tok; std::string result = model->vocab.id_to_token[token].tok;
if (llama_vocab_type(model->vocab) == "spm") { if(llama_vocab_type(model->vocab) == "spm") {
result = llama_unescape_whitespace(result); result = llama_unescape_whitespace(result);
} }
if (length < (int) result.length()) { if (length < (int) result.length()) {
return -result.length(); return -result.length();
} }
strcpy(str, result.c_str()); strncpy(str, result.c_str(), result.length());
return result.length(); return result.length();
} else if (llama_is_unknown_token(model->vocab, token)) { } else if (llama_is_unknown_token(model->vocab, token)) {
if (length < 3) { if (length < 3) {
return -3; return -3;
} }
strcpy(str, "\xe2\x96\x85"); strncpy(str, "\xe2\x96\x85", 3);
return 3; return 3;
} else if (llama_is_control_token(model->vocab, token)) { } else if (llama_is_control_token(model->vocab, token)) {
; ;
} else if (llama_is_byte_token(model->vocab, token)) { } else if (llama_is_byte_token(model->vocab, token)) {
if(1 > length) { if (length < 1) {
return -1; return -1;
} }
str[0] = llama_byte_to_char(model->vocab, token); str[0] = llama_byte_to_char(model->vocab, token);
@ -4452,7 +4453,7 @@ int llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token,
if (length < (int) result.length()) { if (length < (int) result.length()) {
return -result.length(); return -result.length();
} }
strcpy(str, result.c_str()); strncpy(str, result.c_str(), result.length());
return result.length(); return result.length();
} }
return 0; return 0;
@ -4463,9 +4464,8 @@ std::string llama_token_to_str_bpe(const struct llama_context * ctx, llama_token
const int length = llama_token_to_str_bpe(ctx, token, result.data(), result.size()); const int length = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
if (length < 0) { if (length < 0) {
result.resize(-length); result.resize(-length);
const int check = llama_token_to_str_bpe(ctx, token, (char*)result.data(), result.size()); const int check = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
GGML_ASSERT(check == -length); GGML_ASSERT(check == -length);
GGML_UNUSED(check);
} else { } else {
result.resize(length); result.resize(length);
} }

View file

@ -106,7 +106,12 @@ int main(int argc, char **argv) {
std::wstring_convert<typename std::codecvt_utf8<wchar_t>, wchar_t> converter; std::wstring_convert<typename std::codecvt_utf8<wchar_t>, wchar_t> converter;
for (wchar_t ch = 0x0000; ch < 0xffff; ++ch) { for (wchar_t ch = 0x0000; ch < 0xffff; ++ch) {
std::wstring wstr(1, ch); std::wstring wstr(1, ch);
std::string str = converter.to_bytes(wstr); std::string str;
try {
str = converter.to_bytes(wstr);
} catch (std::exception & e) {
continue;
}
std::vector<llama_token> tokens = llama_tokenize(ctx, escape_whitespace(str), false); std::vector<llama_token> tokens = llama_tokenize(ctx, escape_whitespace(str), false);
if (tokens.size() == 1) { if (tokens.size() == 1) {
fprintf(stderr, "%s : info: %s tokenized to %d \n", fprintf(stderr, "%s : info: %s tokenized to %d \n",