diff --git a/llama.cpp b/llama.cpp index 9c1b2a93e..86c943fc6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -777,8 +777,10 @@ struct llama_vocab { float score; }; + llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; + std::unordered_map token_to_id; - std::vector id_to_token; + std::vector id_to_token; // default LLaMA special tokens id special_bos_id = 1; @@ -1406,6 +1408,19 @@ static void llama_model_load_internal( } \ } + std::string tokenizer_name; + GGUF_GET(tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, "tokenizer.ggml.model"); + + if (tokenizer_name == "llama") { + vocab.type = LLAMA_VOCAB_TYPE_SPM; + } else if (tokenizer_name == "gpt2") { + vocab.type = LLAMA_VOCAB_TYPE_BPE; + } else { + LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); + LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); + vocab.type = LLAMA_VOCAB_TYPE_SPM; + } + // get hparams kv GGUF_GET(hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, "tokenizer.ggml.tokens"); GGUF_GET(hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.context_length"); @@ -1504,6 +1519,7 @@ static void llama_model_load_internal( // hparams LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml->fver)); LLAMA_LOG_INFO("%s: arch = %s\n", __func__, general_arch.c_str()); + LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx); @@ -2317,40 +2333,22 @@ static bool llama_eval_internal( // tokenizer // -static std::string llama_vocab_type(const llama_vocab & vocab) { - return vocab.token_to_id.size() == 32000 ? "spm": "bpe"; +static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) { + return vocab.type; } static bool llama_is_normal_token(const llama_vocab & vocab, llama_token token) { - if (llama_vocab_type(vocab) == "spm") { + if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) { return token >= 259; } - if (llama_vocab_type(vocab) == "bpe") { + if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) { return token >= 95; } return false; } -static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token token) { - if (llama_vocab_type(vocab) == "spm") { - return token == 0; - } - - // TODO: improve? - return false; -} - -static bool llama_is_control_token(const llama_vocab & vocab, llama_token token) { - if (llama_vocab_type(vocab) == "spm") { - return token == 1 || token == 2; - } - - // TODO: improve? - return false; -} - static bool llama_is_bos_token(const llama_vocab & vocab, llama_token token) { return token == vocab.special_bos_id; } @@ -2359,6 +2357,24 @@ static bool llama_is_eos_token(const llama_vocab & vocab, llama_token token) { return token == vocab.special_eos_id; } +static bool llama_is_control_token(const llama_vocab & vocab, llama_token token) { + if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) { + return token == llama_is_bos_token(vocab, token) || token == llama_is_eos_token(vocab, token); + } + + // TODO: improve? + return false; +} + +static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token token) { + if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) { + return token == 0; + } + + // TODO: improve? + return false; +} + static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) { GGML_UNUSED(vocab); GGML_UNUSED(token); @@ -2374,11 +2390,11 @@ static bool llama_is_unused_token(const llama_vocab & vocab, llama_token token) } static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) { - if (llama_vocab_type(vocab) == "spm") { + if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) { return 3 <= token && token < 259; } - if (llama_vocab_type(vocab) == "bpe") { + if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) { return 1 <= token && token < 95; } @@ -2386,11 +2402,11 @@ static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) { } static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) { - if (llama_vocab_type(vocab) == "spm") { + if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) { return byte - 3; } - if (llama_vocab_type(vocab) == "bpe") { + if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) { return byte + 32; } @@ -2398,11 +2414,11 @@ static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) { } static uint8_t llama_char_to_byte(const llama_vocab & vocab, uint8_t ch) { - if (llama_vocab_type(vocab) == "spm") { + if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) { return ch + 3; } - if (llama_vocab_type(vocab) == "bpe") { + if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) { return ch - 32; } @@ -5027,7 +5043,7 @@ int llama_tokenize_with_model( llama_token * tokens, int n_max_tokens, bool add_bos) { - auto escape = llama_vocab_type(model->vocab) == "spm"; + auto escape = llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM; auto res = llama_tokenize_internal(model->vocab, text, add_bos, escape); if (n_max_tokens < (int) res.size()) { @@ -5063,7 +5079,7 @@ int llama_token_to_str_with_model(const struct llama_model * model, llama_token if (0 <= token && token < llama_model_n_vocab(model)) { if (llama_is_normal_token(model->vocab, token)) { std::string result = model->vocab.id_to_token[token].tok; - if (llama_vocab_type(model->vocab) == "spm") { + if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) { result = llama_unescape_whitespace(result); } if (length < (int) result.length()) { diff --git a/llama.h b/llama.h index 54081840b..0ea65c1b5 100644 --- a/llama.h +++ b/llama.h @@ -61,6 +61,40 @@ extern "C" { typedef int llama_token; + enum llama_log_level { + LLAMA_LOG_LEVEL_ERROR = 2, + LLAMA_LOG_LEVEL_WARN = 3, + LLAMA_LOG_LEVEL_INFO = 4 + }; + + enum llama_vocab_type { + LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece + LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding + }; + + // model file types + enum llama_ftype { + LLAMA_FTYPE_ALL_F32 = 0, + LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 + // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed + // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed + LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors + }; + typedef struct llama_token_data { llama_token id; // token id float logit; // log-odds of the token @@ -75,19 +109,6 @@ extern "C" { typedef void (*llama_progress_callback)(float progress, void *ctx); - enum llama_log_level { - LLAMA_LOG_LEVEL_ERROR = 2, - LLAMA_LOG_LEVEL_WARN = 3, - LLAMA_LOG_LEVEL_INFO = 4 - }; - - // Signature for logging events - // Note that text includes the new line character at the end for most events. - // If your logging mechanism cannot handle that, check if the last character is '\n' and strip it - // if it exists. - // It might not exist for progress report where '.' is output repeatedly. - typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data); - struct llama_context_params { uint32_t seed; // RNG seed, -1 for random int32_t n_ctx; // text context @@ -117,28 +138,12 @@ extern "C" { bool embedding; // embedding mode only }; - // model file types - enum llama_ftype { - LLAMA_FTYPE_ALL_F32 = 0, - LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 - // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed - // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed - LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors - }; + // Signature for logging events + // Note that text includes the new line character at the end for most events. + // If your logging mechanism cannot handle that, check if the last character is '\n' and strip it + // if it exists. + // It might not exist for progress report where '.' is output repeatedly. + typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data); // model quantization parameters typedef struct llama_model_quantize_params {