llama : towards llama3 tokenization support (wip)

Georgi Gerganov 2024-04-26 14:55:03 +03:00
parent ed42711b90
commit 4907e41aa7
8 changed files with 298 additions and 121 deletions


@@ -215,50 +215,50 @@ class Model(ABC):
         except KeyError:
             raise NotImplementedError(f'Architecture {arch!r} not supported!') from None

-    @staticmethod
-    def from_model_architecture(model_architecture):
-        if model_architecture == "GPTNeoXForCausalLM":
-            return GPTNeoXModel
-        if model_architecture == "BloomForCausalLM":
-            return BloomModel
-        if model_architecture == "MPTForCausalLM":
-            return MPTModel
-        if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
-            return BaichuanModel
-        if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
-            return FalconModel
-        if model_architecture == "GPTBigCodeForCausalLM":
-            return StarCoderModel
-        if model_architecture == "GPTRefactForCausalLM":
-            return RefactModel
-        if model_architecture == "PersimmonForCausalLM":
-            return PersimmonModel
-        if model_architecture == "LlamaForCausalLM":
-            return DeepseekCoderModel
-        if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
-            return StableLMModel
-        if model_architecture == "QWenLMHeadModel":
-            return QwenModel
-        if model_architecture == "Qwen2ForCausalLM":
-            return Model
-        if model_architecture == "MixtralForCausalLM":
-            return MixtralModel
-        if model_architecture == "GPT2LMHeadModel":
-            return GPT2Model
-        if model_architecture == "PhiForCausalLM":
-            return Phi2Model
-        if model_architecture == "PlamoForCausalLM":
-            return PlamoModel
-        if model_architecture == "CodeShellForCausalLM":
-            return CodeShellModel
-        if model_architecture == "OrionForCausalLM":
-            return OrionModel
-        if model_architecture == "InternLM2ForCausalLM":
-            return InternLM2Model
-        if model_architecture == "MiniCPMForCausalLM":
-            return MiniCPMModel
-        if model_architecture == "BertModel":
-            return BertModel
+    # @staticmethod
+    # def from_model_architecture(model_architecture):
+    #     if model_architecture == "GPTNeoXForCausalLM":
+    #         return GPTNeoXModel
+    #     if model_architecture == "BloomForCausalLM":
+    #         return BloomModel
+    #     if model_architecture == "MPTForCausalLM":
+    #         return MPTModel
+    #     if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
+    #         return BaichuanModel
+    #     if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
+    #         return FalconModel
+    #     if model_architecture == "GPTBigCodeForCausalLM":
+    #         return StarCoderModel
+    #     if model_architecture == "GPTRefactForCausalLM":
+    #         return RefactModel
+    #     if model_architecture == "PersimmonForCausalLM":
+    #         return PersimmonModel
+    #     if model_architecture == "LlamaForCausalLM":
+    #         return LlamaModel
+    #     if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
+    #         return StableLMModel
+    #     if model_architecture == "QWenLMHeadModel":
+    #         return QwenModel
+    #     if model_architecture == "Qwen2ForCausalLM":
+    #         return Model
+    #     if model_architecture == "MixtralForCausalLM":
+    #         return MixtralModel
+    #     if model_architecture == "GPT2LMHeadModel":
+    #         return GPT2Model
+    #     if model_architecture == "PhiForCausalLM":
+    #         return Phi2Model
+    #     if model_architecture == "PlamoForCausalLM":
+    #         return PlamoModel
+    #     if model_architecture == "CodeShellForCausalLM":
+    #         return CodeShellModel
+    #     if model_architecture == "OrionForCausalLM":
+    #         return OrionModel
+    #     if model_architecture == "InternLM2ForCausalLM":
+    #         return InternLM2Model
+    #     if model_architecture == "MiniCPMForCausalLM":
+    #         return MiniCPMModel
+    #     if model_architecture == "BertModel":
+    #         return BertModel

     @staticmethod
     def from_model_name(model_name: str):
@@ -281,10 +281,8 @@ class Model(ABC):
             return RefactModel
         if model_name_lower == "persimmon":
             return PersimmonModel
-        if model_name_lower == "deepseekcoder":
-            return DeepseekCoderModel
-        if model_name_lower == "deepseekllm":
-            return DeepseekLLMModel
+        if model_name_lower in ("llama", "deepseekcoder", "deepseekllm"):
+            return LlamaModel
         return Model

     def _is_model_safetensors(self) -> bool:
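
The hard-coded architecture table commented out above is redundant because converter classes now attach themselves through the @Model.register decorator (the registry-backed from_model_architecture above it, and the @Model.register lines later in this diff), and from_model_name folds the DeepseekCoder/DeepseekLLM special cases into LlamaModel. The snippet below is a minimal sketch of how such a decorator registry works; names and structure are illustrative only, not the exact implementation in the conversion script.

    # Illustrative sketch only -- not the exact implementation in this repo.
    # A decorator-based registry makes the long if/elif architecture chain unnecessary.
    _model_classes: dict[str, type] = {}

    class Model:
        @classmethod
        def register(cls, *names):
            def wrapper(model_cls):
                for name in names:
                    _model_classes[name] = model_cls
                return model_cls
            return wrapper

        @classmethod
        def from_model_architecture(cls, arch: str) -> type:
            try:
                return _model_classes[arch]
            except KeyError:
                raise NotImplementedError(f'Architecture {arch!r} not supported!') from None

    @Model.register("LlamaForCausalLM")
    class LlamaModel(Model):
        pass

    print(Model.from_model_architecture("LlamaForCausalLM"))  # <class '__main__.LlamaModel'>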
@@ -376,7 +374,6 @@ class Model(ABC):
         return tokens, toktypes

     def _set_vocab_gpt2(self, tokenizer_model:str = "gpt2") -> None:
         tokens, toktypes = self.get_basic_vocab()
-
         self.gguf_writer.add_tokenizer_model(tokenizer_model)
@@ -1312,31 +1309,7 @@ class PersimmonModel(Model):
             n_dims = len(data.shape)
             print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
             self.gguf_writer.add_tensor(new_name, data)

-
-@Model.register("LlamaForCausalLM")
-class DeepseekCoderModel(Model):
-    model_arch = gguf.MODEL_ARCH.LLAMA
-
-    def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        head_count = self.hparams["num_attention_heads"]
-        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
-        self.gguf_writer.add_head_count(head_count)
-        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
-        self.gguf_writer.add_head_count_kv(head_count_kv)
-        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
-        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
-        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
-            if self.hparams["rope_scaling"].get("type") == "linear":
-                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
-                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
-
-    def set_vocab(self):
-        self._set_vocab_gpt2("deepseek_coder")
-
-
-class DeepseekLLMModel(DeepseekCoderModel):
-    def set_vocab(self):
-        self._set_vocab_gpt2("deepseek_llm")
-
 @Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM")
 class StableLMModel(Model):
@@ -1479,6 +1452,11 @@ class LlamaModel(Model):
         self.gguf_writer.add_vocab_size(hparams["vocab_size"])
         self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])

+        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
+            if self.hparams["rope_scaling"].get("type") == "linear":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+
     # Same as super class, but permuting q_proj, k_proj
     def write_tensors(self):
         block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))


@@ -2114,6 +2114,7 @@ struct llama_vocab {
         ttype type;
     };

+    enum llm_arch         arch = LLM_ARCH_UNKNOWN;
     enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;

     std::unordered_map<token, id> token_to_id;
@@ -4243,10 +4244,6 @@ static void llm_load_vocab(
         } else {
             if (tokenizer_name == "gpt2") {
                 vocab.type = LLAMA_VOCAB_TYPE_BPE;
-            } else if (tokenizer_name == "deepseek_coder") {
-                vocab.type = LLAMA_VOCAB_TYPE_DEEPSEEKCODER;
-            } else if (tokenizer_name == "deepseek_llm") {
-                vocab.type = LLAMA_VOCAB_TYPE_DEEPSEEKLLM;
             } else {
                 LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
                 LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
@@ -4287,6 +4284,8 @@ static void llm_load_vocab(
             vocab.special_cls_id  = -1;
             vocab.special_mask_id = -1;
         }
+
+        vocab.arch = model.arch;
     }

     const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
@@ -11784,10 +11783,9 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
             auto buf = token_data.text.substr(3, 2);
             return strtol(buf.c_str(), NULL, 16);
         }
-        case LLAMA_VOCAB_TYPE_DEEPSEEKCODER:
         case LLAMA_VOCAB_TYPE_BPE: {
             GGML_ASSERT(false);
-            return unicode_utf8_to_byte(token_data.text);
+            return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after GGML_ASSERT?
         }
         case LLAMA_VOCAB_TYPE_WPM: {
             GGML_ASSERT(false);
@@ -11812,7 +11810,6 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
             return vocab.token_to_id.at(buf2);
         }
         case LLAMA_VOCAB_TYPE_WPM:
-        case LLAMA_VOCAB_TYPE_DEEPSEEKCODER:
         case LLAMA_VOCAB_TYPE_BPE: {
             return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
         }
@@ -12014,33 +12011,43 @@ struct llm_tokenizer_bpe {
         std::vector<std::string> word_collection;
         switch (vocab.type) {
             case LLAMA_VOCAB_TYPE_BPE:
-                word_collection = unicode_regex_split(text, {
-                    "\\p{P}+",
-                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                    "\\p{N}+",
-                    "[0-9][0-9][0-9]"
-                });
-                break;
-            case LLAMA_VOCAB_TYPE_DEEPSEEKCODER:
-                word_collection = unicode_regex_split(text, {
-                    "[\r\n]",
-                    "\\s?\\p{L}+",
-                    "\\s?\\p{P}+",
-                    "[一-龥ࠀ-一가-퟿]+",
-                    "\\p{N}+"
-                });
-                break;
-            case LLAMA_VOCAB_TYPE_DEEPSEEKLLM:
-                word_collection = unicode_regex_split(text, {
-                    "[\r\n]",
-                    "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ--ℝℤΩℨK--ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA--z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
-                    "\\s?[!-/:-~---‟ -。]+",
-                    "\\s+$",
-                    "[一-龥ࠀ-一가-퟿]+",
-                    "\\p{N}+"
-                });
+                switch (vocab.arch) {
+                    // TODO: how to detect deepseek and llama v3 models?
+                    //case LLM_ARCH_LLAMA:
+                    //case LLM_ARCH_DEEPSEEK_CODER:
+                    //    word_collection = unicode_regex_split(text, {
+                    //        "[\r\n]",
+                    //        "\\s?\\p{L}+",
+                    //        "\\s?\\p{P}+",
+                    //        "[一-龥ࠀ-一가-퟿]+",
+                    //        "\\p{N}+"
+                    //    });
+                    //    break;
+                    //case LLM_ARCH_DEEPSEEK_LLM:
+                    //    word_collection = unicode_regex_split(text, {
+                    //        "[\r\n]",
+                    //        "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ--ℝℤΩℨK--ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA--z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
+                    //        "\\s?[!-/:-~---‟ -。]+",
+                    //        "\\s+$",
+                    //        "[一-龥ࠀ-一가-퟿]+",
+                    //        "\\p{N}+"
+                    //    });
+                    //    break;
+                    default:
+                        // default regex for BPE tokenization pre-processing
+                        {
+                            word_collection = unicode_regex_split(text, {
+                                "\\p{P}+",
+                                "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                                "\\p{N}+",
+                                "[0-9][0-9][0-9]"
+                            });
+                        }
+                        break;
+                }
                 break;
             default:
                 GGML_ASSERT(false);
                 break;
         }
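
The default branch above keeps the GPT-2 style pre-tokenization patterns, while the commented-out cases park the Deepseek-specific splits until the TODO is resolved, i.e. until there is a reliable way to tell Deepseek and LLaMA v3 models apart at load time. As a rough illustration of what one of these patterns does, the sketch below applies the second default pattern using Python's third-party regex module (which understands \p{L}/\p{N}). It is an approximation for illustration only; llama.cpp applies the full list of patterns in sequence via unicode_regex_split.

    # Rough illustration only: apply the GPT-2 style pattern from the default
    # branch above using the third-party `regex` module (pip install regex).
    # llama.cpp itself runs the whole pattern list through unicode_regex_split.
    import regex

    GPT2_SPLIT = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)"

    def pre_tokenize(text: str) -> list[str]:
        # each returned piece is later BPE-merged independently
        return regex.findall(GPT2_SPLIT, text)

    print(pre_tokenize("Hello, world! 3333"))
    # ['Hello', ',', ' world', '!', ' 3333']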
@@ -12486,8 +12493,6 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                     output.push_back(vocab.special_eos_id);
                 }
             } break;
-        case LLAMA_VOCAB_TYPE_DEEPSEEKCODER:
-        case LLAMA_VOCAB_TYPE_DEEPSEEKLLM:
         case LLAMA_VOCAB_TYPE_BPE:
             {
                 if (add_special && vocab.special_add_bos == 1) {
@@ -17188,8 +17193,6 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
                 }
                 break;
             }
-            case LLAMA_VOCAB_TYPE_DEEPSEEKCODER:
-            case LLAMA_VOCAB_TYPE_DEEPSEEKLLM:
             case LLAMA_VOCAB_TYPE_BPE: {
                 // NOTE: we accept all unsupported token types,
                 // suppressing them like CONTROL tokens.


@@ -67,8 +67,6 @@ extern "C" {
         LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
         LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
         LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
-        LLAMA_VOCAB_TYPE_DEEPSEEKCODER = 4, // Deepseek Coder
-        LLAMA_VOCAB_TYPE_DEEPSEEKLLM   = 5, // Deepseek LLM
     };

     // note: these values should be synchronized with ggml_rope


@@ -41,12 +41,13 @@
 llama_test(test-quantize-perf.cpp)
 llama_test(test-sampling.cpp)
 llama_test(test-chat-template.cpp)

-llama_test(test-tokenizer-0-llama.cpp          NAME test-tokenizer-0-llama          ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
-llama_test(test-tokenizer-0-falcon.cpp         NAME test-tokenizer-0-falcon         ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
-llama_test(test-tokenizer-0-deepseek-coder.cpp NAME test-tokenizer-0-deepseek-coder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-coder.gguf)
-llama_test(test-tokenizer-0-deepseek-llm.cpp   NAME test-tokenizer-0-deepseek-llm   ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-llm.gguf)
+# TODO: tmp disabled LLaMA v3 and Deepseek tests
+llama_test(test-tokenizer-0-llama.cpp           NAME test-tokenizer-0-llama          ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
+#llama_test(test-tokenizer-0-llama-v3.cpp       NAME test-tokenizer-0-llama-v3       ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-v3.gguf)
+llama_test(test-tokenizer-0-falcon.cpp          NAME test-tokenizer-0-falcon         ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
+#llama_test(test-tokenizer-0-deepseek-coder.cpp NAME test-tokenizer-0-deepseek-coder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-coder.gguf)
+#llama_test(test-tokenizer-0-deepseek-llm.cpp   NAME test-tokenizer-0-deepseek-llm   ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-llm.gguf)
 llama_test(test-tokenizer-1-llama.cpp  NAME test-tokenizer-1-llama    ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
 llama_test(test-tokenizer-1-llama.cpp  NAME test-tokenizer-1-baichuan ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)


@@ -89,8 +89,8 @@ int main(int argc, char **argv) {
         }
     }

-    if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_DEEPSEEKCODER) {
-        fprintf(stderr, "%s : error: vocab type is not DEEPSEEKCODER\n", __func__);
+    if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) {
+        fprintf(stderr, "%s : error: vocab type is not BPE\n", __func__);
         llama_free_model(model);
         llama_free(ctx);
         return 2;


@@ -89,8 +89,8 @@ int main(int argc, char **argv) {
         }
     }

-    if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_DEEPSEEKLLM) {
-        fprintf(stderr, "%s : error: vocab type is not DEEPSEEKLLM\n", __func__);
+    if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) {
+        fprintf(stderr, "%s : error: vocab type is not BPE\n", __func__);
         llama_free_model(model);
         llama_free(ctx);
         return 2;


@@ -0,0 +1,197 @@
#include "llama.h"
#include "common.h"
#include "console.h"

#include <cstdio>
#include <string>
#include <map>
#include <vector>
#include <fstream>

// generate using test-tokenizer-0-llama.py
static const std::map<std::string, std::vector<llama_token>> & k_tests() {
    static std::map<std::string, std::vector<llama_token>> _k_tests = {
        { ""                       , { }, },
        { " "                      , { 220, }, },
        { "  "                     , { 256, }, },
        { "   "                    , { 262, }, },
        { "\t"                     , { 197, }, },
        { "\n"                     , { 198, }, },
        { "\t\n"                   , { 1602, }, },
        { "Hello world"            , { 9906, 1917, }, },
        { " Hello world"           , { 22691, 1917, }, },
        { "Hello World"            , { 9906, 4435, }, },
        { " Hello World"           , { 22691, 4435, }, },
        { " Hello World!"          , { 22691, 4435, 0, }, },
        { "Hello, world!"          , { 9906, 11, 1917, 0, }, },
        { " Hello, world!"         , { 22691, 11, 1917, 0, }, },
        { " this is 🦙.cpp"         , { 420, 374, 11410, 99, 247, 13, 11055, }, },
        { "w048 7tuijk dsdfhu"     , { 86, 23904, 220, 22, 83, 2005, 42908, 11729, 3013, 17156, }, },
        { "нещо на Български"      , { 79862, 102118, 13373, 64571, 34694, 3114, 112203, 80112, }, },
        { "កាន់តែពិសេសអាចខលចេញ"    , { 21549, 222, 98629, 241, 45358, 233, 21549, 237, 45358, 224, 21549, 244, 21549, 115, 21549, 253, 45358, 223, 21549, 253, 21549, 95, 98629, 227, 21549, 223, 21549, 249, 21549, 227, 45358, 223, 21549, 231, }, },
        { "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", { 9468, 248, 222, 320, 8416, 8, 27623, 114, 102470, 9468, 234, 104, 31643, 320, 36773, 100166, 98634, 8, 26602, 227, 320, 3323, 43465, 430, 706, 1202, 1866, 4037, 8, }, },
        { "Hello"                  , { 9906, }, },
        { " Hello"                 , { 22691, }, },
        { "  Hello"                , { 220, 22691, }, },
        { "   Hello"               , { 256, 22691, }, },
        { "    Hello"              , { 262, 22691, }, },
        { "    Hello\n    Hello"   , { 262, 22691, 198, 262, 22691, }, },
        { " ("                     , { 320, }, },
        { "\n ="                   , { 198, 284, }, },
        { "' era"                  , { 6, 11639, }, },
        { "Hello, y'all! How are you 😁 ?我想在apple工作1314151天", { 9906, 11, 379, 65948, 0, 2650, 527, 499, 27623, 223, 949, 37046, 101067, 19000, 23182, 102301, 9263, 18136, 16, 36827, 21909, }, },
        { "3"                      , { 18, }, },
        { "33"                     , { 1644, }, },
        { "333"                    , { 8765, }, },
        { "3333"                   , { 8765, 18, }, },
        { "33333"                  , { 8765, 1644, }, },
        { "333333"                 , { 8765, 8765, }, },
        { "3333333"                , { 8765, 8765, 18, }, },
        { "33333333"               , { 8765, 8765, 1644, }, },
        { "333333333"              , { 8765, 8765, 8765, }, },
    };

    return _k_tests;
}

int main(int argc, char **argv) {
    if (argc < 2) {
        fprintf(stderr, "Usage: %s vocab-file [text-file]\n", argv[0]);
        return 1;
    }

    const std::string fname = argv[1];

    std::string fname_text;
    if (argc > 2) {
        fname_text = argv[2];
    }

    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());

    llama_model * model;
    llama_context * ctx;

    llama_backend_init();

    // load the vocab
    {
        auto mparams = llama_model_default_params();

        mparams.vocab_only = true;

        model = llama_load_model_from_file(fname.c_str(), mparams);

        if (model == NULL) {
            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
            return 1;
        }

        auto cparams = llama_context_default_params();

        ctx = llama_new_context_with_model(model, cparams);

        if (ctx == NULL) {
            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
            llama_free_model(model);
            return 1;
        }
    }

    if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) {
        fprintf(stderr, "%s : error: vocab type is not BPE\n", __func__);
        llama_free_model(model);
        llama_free(ctx);
        return 2;
    }

#ifdef _WIN32
    // We need this for unicode console support
    console::init(false, false);
    atexit([]() { console::cleanup(); });
#endif

    bool success = true;

    for (const auto & test_kv : k_tests()) {
        const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, false);

        printf("\n");
        printf("src: '%s'\n", test_kv.first.c_str());
        printf("res: '%s'\n", llama_detokenize_bpe(ctx, res).c_str());
        printf("tok: ");
        for (const auto & tok : res) {
            printf("%d ", tok);
        }
        printf("\n");

        bool correct = res.size() == test_kv.second.size();
        for (int i = 0; i < (int) res.size() && correct; ++i) {
            if (test_kv.second[i] != res[i]) {
                correct = false;
            }
        }

        if (!correct) {
            fprintf(stderr, "%s : failed test:    '%s'\n", __func__, test_kv.first.c_str());
            fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
                llama_detokenize_bpe(ctx, res).c_str(),
                llama_detokenize_bpe(ctx, test_kv.second).c_str());
            fprintf(stderr, "%s : expected tokens: ", __func__);
            for (const auto & t : test_kv.second) {
                fprintf(stderr, "%6d, ", t);
            }
            fprintf(stderr, "\n");
            fprintf(stderr, "%s : got tokens:      ", __func__);
            for (const auto & t : res) {
                fprintf(stderr, "%6d, ", t);
            }
            fprintf(stderr, "\n");

            success = false;
        }
    }

    if (!fname_text.empty()) {
        fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());

        std::string text;
        {
            std::ifstream ifs(fname_text);
            if (!ifs) {
                fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_text.c_str());
                return 1;
            }
            text = std::string(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
        }

        fprintf(stderr, "%s : text size: %zu\n", __func__, text.size());

        const std::vector<llama_token> res = llama_tokenize(ctx, text, false);

        fprintf(stderr, "%s : tokens: %zu\n", __func__, res.size());

        {
            const std::string fname_out = fname_text + ".tokcpp";

            std::ofstream ofs(fname_out);
            if (!ofs) {
                fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_out.c_str());
                return 1;
            }

            for (const auto & tok : res) {
                ofs << tok << " '" << llama_detokenize_bpe(ctx, std::vector<int>{tok}) << "'" << std::endl;
            }
        }

        fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str());
    }

    llama_free_model(model);
    llama_free(ctx);

    llama_backend_free();

    return success ? 0 : 3;
}
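
The reference token IDs in k_tests() come from the Python helper referenced in the comment at the top of the file. A minimal sketch of such a generator is shown below; loading the tokenizer with transformers.AutoTokenizer from a local directory passed on the command line is an assumption for illustration, and the actual helper script in the repository may differ.

    # Sketch of a reference-token generator (assumption: the real helper script
    # may differ). Loads a Hugging Face tokenizer from a local directory and
    # prints token IDs for a few of the strings used in k_tests().
    import sys
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(sys.argv[1])  # path to the tokenizer directory

    tests = ["", " ", "  ", "   ", "\t", "\n", "Hello world", " Hello world",
             "Hello, y'all! How are you 😁 ?我想在apple工作1314151天"]

    for text in tests:
        ids = tokenizer.encode(text, add_special_tokens=False)
        print(f"{text!r:>48} -> {ids}")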


@@ -126,7 +126,7 @@ int main(int argc, char **argv) {
         }
         printf("\n");

-        bool correct = res_nobos.size() == test_kv.second.size() && res_bos.size() == res_nobos.size() + 1 && res_bos[0] == 1;
+        bool correct = res_nobos.size() == test_kv.second.size() && res_bos.size() == res_nobos.size() + 1 && res_bos[0] == llama_token_bos(model);
         for (int i = 0; i < (int) res_nobos.size() && correct; ++i) {
             if (test_kv.second[i] != res_bos[i + 1]) {