From d2f8c9d51bd7472a84a85eecd92ce387c54d1df4 Mon Sep 17 00:00:00 2001
From: goerch
Date: Tue, 24 Oct 2023 13:46:44 +0200
Subject: [PATCH] Fix detokenization of non-special added-tokens

---
 examples/CMakeLists.txt                | 2 +-
 examples/benchmark/CMakeLists.txt      | 1 -
 examples/llava/CMakeLists.txt          | 2 +-
 examples/quantize-stats/CMakeLists.txt | 1 -
 examples/quantize/CMakeLists.txt       | 1 -
 examples/server/CMakeLists.txt         | 3 +--
 llama.cpp                              | 26 ++++++++++++++++++--------
 7 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 75b8df676..7fa9d53f3 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -8,7 +8,7 @@ find_package(Threads REQUIRED)
 
 # examples
 
-include_directories(${CMAKE_CURRENT_SOURCE_DIR})
+include_directories(${CMAKE_CURRENT_SOURCE_DIR} ../common)
 
 if (EMSCRIPTEN)
 else()
diff --git a/examples/benchmark/CMakeLists.txt b/examples/benchmark/CMakeLists.txt
index 14916d831..77777ceed 100644
--- a/examples/benchmark/CMakeLists.txt
+++ b/examples/benchmark/CMakeLists.txt
@@ -2,7 +2,6 @@ set(TARGET benchmark)
 add_executable(${TARGET} benchmark-matmult.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
-target_include_directories(${TARGET} PRIVATE ../../common)
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
 if(TARGET BUILD_INFO)
   add_dependencies(${TARGET} BUILD_INFO)
diff --git a/examples/llava/CMakeLists.txt b/examples/llava/CMakeLists.txt
index 2d7979ecd..ddbc9475c 100644
--- a/examples/llava/CMakeLists.txt
+++ b/examples/llava/CMakeLists.txt
@@ -13,7 +13,7 @@ endif()
 set(TARGET llava)
 add_executable(${TARGET} llava.cpp)
 install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE common llama clip ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE llama clip ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
 if(TARGET BUILD_INFO)
   add_dependencies(${TARGET} BUILD_INFO)
diff --git a/examples/quantize-stats/CMakeLists.txt b/examples/quantize-stats/CMakeLists.txt
index db182e263..c5c394058 100644
--- a/examples/quantize-stats/CMakeLists.txt
+++ b/examples/quantize-stats/CMakeLists.txt
@@ -2,5 +2,4 @@ set(TARGET quantize-stats)
 add_executable(${TARGET} quantize-stats.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
-target_include_directories(${TARGET} PRIVATE ../../common)
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
diff --git a/examples/quantize/CMakeLists.txt b/examples/quantize/CMakeLists.txt
index 4a8eed544..47d0be72e 100644
--- a/examples/quantize/CMakeLists.txt
+++ b/examples/quantize/CMakeLists.txt
@@ -2,7 +2,6 @@ set(TARGET quantize)
 add_executable(${TARGET} quantize.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
-target_include_directories(${TARGET} PRIVATE ../../common)
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
 if(TARGET BUILD_INFO)
   add_dependencies(${TARGET} BUILD_INFO)
diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt
index a23ddcc55..08ad3fb9c 100644
--- a/examples/server/CMakeLists.txt
+++ b/examples/server/CMakeLists.txt
@@ -1,12 +1,11 @@
 set(TARGET server)
 option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 add_executable(${TARGET} server.cpp json.hpp httplib.h)
 install(TARGETS ${TARGET} RUNTIME)
 target_compile_definitions(${TARGET} PRIVATE
     SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
 )
-target_link_libraries(${TARGET} PRIVATE common llama clip ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE llama clip ${CMAKE_THREAD_LIBS_INIT})
 if (WIN32)
     TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
 endif()
diff --git a/llama.cpp b/llama.cpp
index 61f30c398..d9e627e4b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -9750,6 +9750,8 @@ int llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int length) {
     if (0 <= token && token < llama_n_vocab(model)) {
         switch (llama_vocab_get_type(model->vocab)) {
         case LLAMA_VOCAB_TYPE_SPM: {
+            // NOTE: we accept all unsupported token types,
+            // suppressing them like CONTROL tokens.
             if (llama_is_normal_token(model->vocab, token)) {
                 std::string result = model->vocab.id_to_token[token].text;
                 llama_unescape_whitespace(result);
@@ -9758,6 +9760,13 @@ int llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int length) {
                 }
                 memcpy(buf, result.c_str(), result.length());
                 return result.length();
+            } else if (llama_is_user_defined_token(model->vocab, token)) {
+                std::string result = model->vocab.id_to_token[token].text;
+                if (length < (int) result.length()) {
+                    return -result.length();
+                }
+                memcpy(buf, result.c_str(), result.length());
+                return result.length();
             } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
                 if (length < 3) {
                     return -3;
@@ -9772,14 +9781,12 @@ int llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int length) {
                 }
                 buf[0] = llama_token_to_byte(model->vocab, token);
                 return 1;
-            } else {
-                // TODO: for now we accept all unsupported token types,
-                // suppressing them like CONTROL tokens.
-                // GGML_ASSERT(false);
             }
             break;
         }
         case LLAMA_VOCAB_TYPE_BPE: {
+            // NOTE: we accept all unsupported token types,
+            // suppressing them like CONTROL tokens.
             if (llama_is_normal_token(model->vocab, token)) {
                 std::string result = model->vocab.id_to_token[token].text;
                 result = llama_decode_text(result);
@@ -9788,12 +9795,15 @@ int llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int length) {
                 }
                 memcpy(buf, result.c_str(), result.length());
                 return result.length();
+            } else if (llama_is_user_defined_token(model->vocab, token)) {
+                std::string result = model->vocab.id_to_token[token].text;
+                if (length < (int) result.length()) {
+                    return -result.length();
+                }
+                memcpy(buf, result.c_str(), result.length());
+                return result.length();
             } else if (llama_is_control_token(model->vocab, token)) {
                 ;
-            } else {
-                // TODO: for now we accept all unsupported token types,
-                // suppressing them like CONTROL tokens.
-                // GGML_ASSERT(false);
             }
             break;
         }
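
For reference, a minimal caller-side sketch (not part of the patch) of the convention that llama_token_to_piece() follows and that this change preserves: a non-negative return is the piece length, a negative return is minus the required buffer size, and suppressed tokens (e.g. CONTROL) yield length 0. With the patch applied, user-defined (non-special added) tokens now come back as their literal text instead of being dropped. The token_to_piece helper name below is hypothetical; it assumes only the llama.h declaration visible in the hunk headers.

// Hypothetical helper; assumes the declaration shown in the hunk headers:
//   int llama_token_to_piece(const struct llama_model *, llama_token, char *, int);
#include <string>
#include <vector>
#include "llama.h"

static std::string token_to_piece(const struct llama_model * model, llama_token token) {
    std::vector<char> buf(8);
    int n = llama_token_to_piece(model, token, buf.data(), (int) buf.size());
    if (n < 0) {
        // a negative return encodes the required buffer size; grow and retry
        buf.resize(-n);
        n = llama_token_to_piece(model, token, buf.data(), (int) buf.size());
    }
    // n == 0 for suppressed tokens (e.g. CONTROL); user-defined tokens now
    // return their stored text verbatim
    return std::string(buf.data(), n);
}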