llama : towards llama3 tokenization support (wip)
This commit is contained in:
parent
ed42711b90
commit
4907e41aa7
8 changed files with 298 additions and 121 deletions
|
@ -215,50 +215,50 @@ class Model(ABC):
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
|
raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
|
||||||
|
|
||||||
@staticmethod
|
# @staticmethod
|
||||||
def from_model_architecture(model_architecture):
|
# def from_model_architecture(model_architecture):
|
||||||
if model_architecture == "GPTNeoXForCausalLM":
|
# if model_architecture == "GPTNeoXForCausalLM":
|
||||||
return GPTNeoXModel
|
# return GPTNeoXModel
|
||||||
if model_architecture == "BloomForCausalLM":
|
# if model_architecture == "BloomForCausalLM":
|
||||||
return BloomModel
|
# return BloomModel
|
||||||
if model_architecture == "MPTForCausalLM":
|
# if model_architecture == "MPTForCausalLM":
|
||||||
return MPTModel
|
# return MPTModel
|
||||||
if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
|
# if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
|
||||||
return BaichuanModel
|
# return BaichuanModel
|
||||||
if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
|
# if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
|
||||||
return FalconModel
|
# return FalconModel
|
||||||
if model_architecture == "GPTBigCodeForCausalLM":
|
# if model_architecture == "GPTBigCodeForCausalLM":
|
||||||
return StarCoderModel
|
# return StarCoderModel
|
||||||
if model_architecture == "GPTRefactForCausalLM":
|
# if model_architecture == "GPTRefactForCausalLM":
|
||||||
return RefactModel
|
# return RefactModel
|
||||||
if model_architecture == "PersimmonForCausalLM":
|
# if model_architecture == "PersimmonForCausalLM":
|
||||||
return PersimmonModel
|
# return PersimmonModel
|
||||||
if model_architecture == "LlamaForCausalLM":
|
# if model_architecture == "LlamaForCausalLM":
|
||||||
return DeepseekCoderModel
|
# return LlamaModel
|
||||||
if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
|
# if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
|
||||||
return StableLMModel
|
# return StableLMModel
|
||||||
if model_architecture == "QWenLMHeadModel":
|
# if model_architecture == "QWenLMHeadModel":
|
||||||
return QwenModel
|
# return QwenModel
|
||||||
if model_architecture == "Qwen2ForCausalLM":
|
# if model_architecture == "Qwen2ForCausalLM":
|
||||||
return Model
|
# return Model
|
||||||
if model_architecture == "MixtralForCausalLM":
|
# if model_architecture == "MixtralForCausalLM":
|
||||||
return MixtralModel
|
# return MixtralModel
|
||||||
if model_architecture == "GPT2LMHeadModel":
|
# if model_architecture == "GPT2LMHeadModel":
|
||||||
return GPT2Model
|
# return GPT2Model
|
||||||
if model_architecture == "PhiForCausalLM":
|
# if model_architecture == "PhiForCausalLM":
|
||||||
return Phi2Model
|
# return Phi2Model
|
||||||
if model_architecture == "PlamoForCausalLM":
|
# if model_architecture == "PlamoForCausalLM":
|
||||||
return PlamoModel
|
# return PlamoModel
|
||||||
if model_architecture == "CodeShellForCausalLM":
|
# if model_architecture == "CodeShellForCausalLM":
|
||||||
return CodeShellModel
|
# return CodeShellModel
|
||||||
if model_architecture == "OrionForCausalLM":
|
# if model_architecture == "OrionForCausalLM":
|
||||||
return OrionModel
|
# return OrionModel
|
||||||
if model_architecture == "InternLM2ForCausalLM":
|
# if model_architecture == "InternLM2ForCausalLM":
|
||||||
return InternLM2Model
|
# return InternLM2Model
|
||||||
if model_architecture == "MiniCPMForCausalLM":
|
# if model_architecture == "MiniCPMForCausalLM":
|
||||||
return MiniCPMModel
|
# return MiniCPMModel
|
||||||
if model_architecture == "BertModel":
|
# if model_architecture == "BertModel":
|
||||||
return BertModel
|
# return BertModel
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_name(model_name: str):
|
def from_model_name(model_name: str):
|
||||||
|
@ -281,10 +281,8 @@ class Model(ABC):
|
||||||
return RefactModel
|
return RefactModel
|
||||||
if model_name_lower == "persimmon":
|
if model_name_lower == "persimmon":
|
||||||
return PersimmonModel
|
return PersimmonModel
|
||||||
if model_name_lower == "deepseekcoder":
|
if model_name_lower in ("llama", "deepseekcoder", "deepseekllm"):
|
||||||
return DeepseekCoderModel
|
return LlamaModel
|
||||||
if model_name_lower == "deepseekllm":
|
|
||||||
return DeepseekLLMModel
|
|
||||||
return Model
|
return Model
|
||||||
|
|
||||||
def _is_model_safetensors(self) -> bool:
|
def _is_model_safetensors(self) -> bool:
|
||||||
|
@ -376,7 +374,6 @@ class Model(ABC):
|
||||||
|
|
||||||
return tokens, toktypes
|
return tokens, toktypes
|
||||||
|
|
||||||
|
|
||||||
def _set_vocab_gpt2(self, tokenizer_model:str = "gpt2") -> None:
|
def _set_vocab_gpt2(self, tokenizer_model:str = "gpt2") -> None:
|
||||||
tokens, toktypes = self.get_basic_vocab()
|
tokens, toktypes = self.get_basic_vocab()
|
||||||
self.gguf_writer.add_tokenizer_model(tokenizer_model)
|
self.gguf_writer.add_tokenizer_model(tokenizer_model)
|
||||||
|
@ -1312,31 +1309,7 @@ class PersimmonModel(Model):
|
||||||
n_dims = len(data.shape)
|
n_dims = len(data.shape)
|
||||||
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||||
self.gguf_writer.add_tensor(new_name, data)
|
self.gguf_writer.add_tensor(new_name, data)
|
||||||
@Model.register("LlamaForCausalLM")
|
|
||||||
class DeepseekCoderModel(Model):
|
|
||||||
model_arch = gguf.MODEL_ARCH.LLAMA
|
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
|
||||||
super().set_gguf_parameters()
|
|
||||||
head_count = self.hparams["num_attention_heads"]
|
|
||||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
|
||||||
self.gguf_writer.add_head_count(head_count)
|
|
||||||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
|
||||||
self.gguf_writer.add_head_count_kv(head_count_kv)
|
|
||||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
|
|
||||||
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
|
|
||||||
|
|
||||||
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
|
||||||
if self.hparams["rope_scaling"].get("type") == "linear":
|
|
||||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
|
||||||
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
|
|
||||||
|
|
||||||
def set_vocab(self):
|
|
||||||
self._set_vocab_gpt2("deepseek_coder")
|
|
||||||
|
|
||||||
class DeepseekLLMModel(DeepseekCoderModel):
|
|
||||||
def set_vocab(self):
|
|
||||||
self._set_vocab_gpt2("deepseek_llm")
|
|
||||||
|
|
||||||
@Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM")
|
@Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM")
|
||||||
class StableLMModel(Model):
|
class StableLMModel(Model):
|
||||||
|
@ -1479,6 +1452,11 @@ class LlamaModel(Model):
|
||||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||||
self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])
|
self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])
|
||||||
|
|
||||||
|
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
||||||
|
if self.hparams["rope_scaling"].get("type") == "linear":
|
||||||
|
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||||
|
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
|
||||||
|
|
||||||
# Same as super class, but permuting q_proj, k_proj
|
# Same as super class, but permuting q_proj, k_proj
|
||||||
def write_tensors(self):
|
def write_tensors(self):
|
||||||
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
|
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
|
||||||
|
|
75
llama.cpp
75
llama.cpp
|
@ -2114,6 +2114,7 @@ struct llama_vocab {
|
||||||
ttype type;
|
ttype type;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum llm_arch arch = LLM_ARCH_UNKNOWN;
|
||||||
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
|
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
|
||||||
|
|
||||||
std::unordered_map<token, id> token_to_id;
|
std::unordered_map<token, id> token_to_id;
|
||||||
|
@ -4243,10 +4244,6 @@ static void llm_load_vocab(
|
||||||
} else {
|
} else {
|
||||||
if (tokenizer_name == "gpt2") {
|
if (tokenizer_name == "gpt2") {
|
||||||
vocab.type = LLAMA_VOCAB_TYPE_BPE;
|
vocab.type = LLAMA_VOCAB_TYPE_BPE;
|
||||||
} else if (tokenizer_name == "deepseek_coder") {
|
|
||||||
vocab.type = LLAMA_VOCAB_TYPE_DEEPSEEKCODER;
|
|
||||||
} else if (tokenizer_name == "deepseek_llm") {
|
|
||||||
vocab.type = LLAMA_VOCAB_TYPE_DEEPSEEKLLM;
|
|
||||||
} else {
|
} else {
|
||||||
LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
|
LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
|
||||||
LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
|
LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
|
||||||
|
@ -4287,6 +4284,8 @@ static void llm_load_vocab(
|
||||||
vocab.special_cls_id = -1;
|
vocab.special_cls_id = -1;
|
||||||
vocab.special_mask_id = -1;
|
vocab.special_mask_id = -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
vocab.arch = model.arch;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
|
const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
|
||||||
|
@ -11784,10 +11783,9 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
|
||||||
auto buf = token_data.text.substr(3, 2);
|
auto buf = token_data.text.substr(3, 2);
|
||||||
return strtol(buf.c_str(), NULL, 16);
|
return strtol(buf.c_str(), NULL, 16);
|
||||||
}
|
}
|
||||||
case LLAMA_VOCAB_TYPE_DEEPSEEKCODER:
|
|
||||||
case LLAMA_VOCAB_TYPE_BPE: {
|
case LLAMA_VOCAB_TYPE_BPE: {
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
return unicode_utf8_to_byte(token_data.text);
|
return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after GGML_ASSERT?
|
||||||
}
|
}
|
||||||
case LLAMA_VOCAB_TYPE_WPM: {
|
case LLAMA_VOCAB_TYPE_WPM: {
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
|
@ -11812,7 +11810,6 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
|
||||||
return vocab.token_to_id.at(buf2);
|
return vocab.token_to_id.at(buf2);
|
||||||
}
|
}
|
||||||
case LLAMA_VOCAB_TYPE_WPM:
|
case LLAMA_VOCAB_TYPE_WPM:
|
||||||
case LLAMA_VOCAB_TYPE_DEEPSEEKCODER:
|
|
||||||
case LLAMA_VOCAB_TYPE_BPE: {
|
case LLAMA_VOCAB_TYPE_BPE: {
|
||||||
return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
|
return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
|
||||||
}
|
}
|
||||||
|
@ -12014,33 +12011,43 @@ struct llm_tokenizer_bpe {
|
||||||
std::vector<std::string> word_collection;
|
std::vector<std::string> word_collection;
|
||||||
switch (vocab.type) {
|
switch (vocab.type) {
|
||||||
case LLAMA_VOCAB_TYPE_BPE:
|
case LLAMA_VOCAB_TYPE_BPE:
|
||||||
word_collection = unicode_regex_split(text, {
|
switch (vocab.arch) {
|
||||||
"\\p{P}+",
|
// TODO: how to detect deepseek and llama v3 models?
|
||||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
//case LLM_ARCH_LLAMA:
|
||||||
"\\p{N}+",
|
//case LLM_ARCH_DEEPSEEK_CODER:
|
||||||
"[0-9][0-9][0-9]"
|
// word_collection = unicode_regex_split(text, {
|
||||||
});
|
// "[\r\n]",
|
||||||
break;
|
// "\\s?\\p{L}+",
|
||||||
case LLAMA_VOCAB_TYPE_DEEPSEEKCODER:
|
// "\\s?\\p{P}+",
|
||||||
word_collection = unicode_regex_split(text, {
|
// "[一-龥ࠀ-一가-]+",
|
||||||
"[\r\n]",
|
// "\\p{N}+"
|
||||||
"\\s?\\p{L}+",
|
// });
|
||||||
"\\s?\\p{P}+",
|
// break;
|
||||||
"[一-龥ࠀ-一가-]+",
|
//case LLM_ARCH_DEEPSEEK_LLM:
|
||||||
"\\p{N}+"
|
// word_collection = unicode_regex_split(text, {
|
||||||
});
|
// "[\r\n]",
|
||||||
break;
|
// "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
|
||||||
case LLAMA_VOCAB_TYPE_DEEPSEEKLLM:
|
// "\\s?[!-/:-~!-/:-~‘-‟ -。]+",
|
||||||
word_collection = unicode_regex_split(text, {
|
// "\\s+$",
|
||||||
"[\r\n]",
|
// "[一-龥ࠀ-一가-]+",
|
||||||
"\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
|
// "\\p{N}+"
|
||||||
"\\s?[!-/:-~!-/:-~‘-‟ -。]+",
|
// });
|
||||||
"\\s+$",
|
// break;
|
||||||
"[一-龥ࠀ-一가-]+",
|
default:
|
||||||
"\\p{N}+"
|
// default regex for BPE tokenization pre-processing
|
||||||
});
|
{
|
||||||
|
word_collection = unicode_regex_split(text, {
|
||||||
|
"\\p{P}+",
|
||||||
|
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
||||||
|
"\\p{N}+",
|
||||||
|
"[0-9][0-9][0-9]"
|
||||||
|
});
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
GGML_ASSERT(false);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12486,8 +12493,6 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
||||||
output.push_back(vocab.special_eos_id);
|
output.push_back(vocab.special_eos_id);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLAMA_VOCAB_TYPE_DEEPSEEKCODER:
|
|
||||||
case LLAMA_VOCAB_TYPE_DEEPSEEKLLM:
|
|
||||||
case LLAMA_VOCAB_TYPE_BPE:
|
case LLAMA_VOCAB_TYPE_BPE:
|
||||||
{
|
{
|
||||||
if (add_special && vocab.special_add_bos == 1) {
|
if (add_special && vocab.special_add_bos == 1) {
|
||||||
|
@ -17188,8 +17193,6 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case LLAMA_VOCAB_TYPE_DEEPSEEKCODER:
|
|
||||||
case LLAMA_VOCAB_TYPE_DEEPSEEKLLM:
|
|
||||||
case LLAMA_VOCAB_TYPE_BPE: {
|
case LLAMA_VOCAB_TYPE_BPE: {
|
||||||
// NOTE: we accept all unsupported token types,
|
// NOTE: we accept all unsupported token types,
|
||||||
// suppressing them like CONTROL tokens.
|
// suppressing them like CONTROL tokens.
|
||||||
|
|
2
llama.h
2
llama.h
|
@ -67,8 +67,6 @@ extern "C" {
|
||||||
LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
|
LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
|
||||||
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
|
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
|
||||||
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
||||||
LLAMA_VOCAB_TYPE_DEEPSEEKCODER = 4, // Deepseek Coder
|
|
||||||
LLAMA_VOCAB_TYPE_DEEPSEEKLLM = 5, // Deepseek LLM
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// note: these values should be synchronized with ggml_rope
|
// note: these values should be synchronized with ggml_rope
|
||||||
|
|
|
@ -41,12 +41,13 @@ llama_test(test-quantize-perf.cpp)
|
||||||
llama_test(test-sampling.cpp)
|
llama_test(test-sampling.cpp)
|
||||||
llama_test(test-chat-template.cpp)
|
llama_test(test-chat-template.cpp)
|
||||||
|
|
||||||
|
# TODO: tmp disabled LLaMA v3 and Deepseek tests
|
||||||
|
llama_test(test-tokenizer-0-llama.cpp NAME test-tokenizer-0-llama ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
|
||||||
|
#llama_test(test-tokenizer-0-llama-v3.cpp NAME test-tokenizer-0-llama-v3 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-v3.gguf)
|
||||||
|
llama_test(test-tokenizer-0-falcon.cpp NAME test-tokenizer-0-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
|
||||||
|
|
||||||
llama_test(test-tokenizer-0-llama.cpp NAME test-tokenizer-0-llama ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
|
#llama_test(test-tokenizer-0-deepseek-coder.cpp NAME test-tokenizer-0-deepseek-coder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-coder.gguf)
|
||||||
llama_test(test-tokenizer-0-falcon.cpp NAME test-tokenizer-0-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
|
#llama_test(test-tokenizer-0-deepseek-llm.cpp NAME test-tokenizer-0-deepseek-llm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-llm.gguf)
|
||||||
|
|
||||||
llama_test(test-tokenizer-0-deepseek-coder.cpp NAME test-tokenizer-0-deepseek-coder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-coder.gguf)
|
|
||||||
llama_test(test-tokenizer-0-deepseek-llm.cpp NAME test-tokenizer-0-deepseek-llm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-llm.gguf)
|
|
||||||
|
|
||||||
llama_test(test-tokenizer-1-llama.cpp NAME test-tokenizer-1-llama ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
|
llama_test(test-tokenizer-1-llama.cpp NAME test-tokenizer-1-llama ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
|
||||||
llama_test(test-tokenizer-1-llama.cpp NAME test-tokenizer-1-baichuan ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
|
llama_test(test-tokenizer-1-llama.cpp NAME test-tokenizer-1-baichuan ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
|
||||||
|
|
|
@ -89,8 +89,8 @@ int main(int argc, char **argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_DEEPSEEKCODER) {
|
if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) {
|
||||||
fprintf(stderr, "%s : error: vocab type is not DEEPSEEKCODER\n", __func__);
|
fprintf(stderr, "%s : error: vocab type is not BPE\n", __func__);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
return 2;
|
return 2;
|
||||||
|
|
|
@ -89,8 +89,8 @@ int main(int argc, char **argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_DEEPSEEKLLM) {
|
if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) {
|
||||||
fprintf(stderr, "%s : error: vocab type is not DEEPSEEKLLM\n", __func__);
|
fprintf(stderr, "%s : error: vocab type is not BPE\n", __func__);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
return 2;
|
return 2;
|
||||||
|
|
197
tests/test-tokenizer-0-llama-v3.cpp
Normal file
197
tests/test-tokenizer-0-llama-v3.cpp
Normal file
|
@ -0,0 +1,197 @@
|
||||||
|
#include "llama.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include "console.h"
|
||||||
|
|
||||||
|
#include <cstdio>
|
||||||
|
#include <string>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
// generate using test-tokenizer-0-llama.py
|
||||||
|
static const std::map<std::string, std::vector<llama_token>> & k_tests() {
|
||||||
|
static std::map<std::string, std::vector<llama_token>> _k_tests = {
|
||||||
|
{ "" , { }, },
|
||||||
|
{ " " , { 220, }, },
|
||||||
|
{ " " , { 256, }, },
|
||||||
|
{ " " , { 262, }, },
|
||||||
|
{ "\t" , { 197, }, },
|
||||||
|
{ "\n" , { 198, }, },
|
||||||
|
{ "\t\n" , { 1602, }, },
|
||||||
|
{ "Hello world" , { 9906, 1917, }, },
|
||||||
|
{ " Hello world" , { 22691, 1917, }, },
|
||||||
|
{ "Hello World" , { 9906, 4435, }, },
|
||||||
|
{ " Hello World" , { 22691, 4435, }, },
|
||||||
|
{ " Hello World!" , { 22691, 4435, 0, }, },
|
||||||
|
{ "Hello, world!" , { 9906, 11, 1917, 0, }, },
|
||||||
|
{ " Hello, world!" , { 22691, 11, 1917, 0, }, },
|
||||||
|
{ " this is 🦙.cpp" , { 420, 374, 11410, 99, 247, 13, 11055, }, },
|
||||||
|
{ "w048 7tuijk dsdfhu" , { 86, 23904, 220, 22, 83, 2005, 42908, 11729, 3013, 17156, }, },
|
||||||
|
{ "нещо на Български" , { 79862, 102118, 13373, 64571, 34694, 3114, 112203, 80112, }, },
|
||||||
|
{ "កាន់តែពិសេសអាចខលចេញ" , { 21549, 222, 98629, 241, 45358, 233, 21549, 237, 45358, 224, 21549, 244, 21549, 115, 21549, 253, 45358, 223, 21549, 253, 21549, 95, 98629, 227, 21549, 223, 21549, 249, 21549, 227, 45358, 223, 21549, 231, }, },
|
||||||
|
{ "🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", { 9468, 248, 222, 320, 8416, 8, 27623, 114, 102470, 9468, 234, 104, 31643, 320, 36773, 100166, 98634, 8, 26602, 227, 320, 3323, 43465, 430, 706, 1202, 1866, 4037, 8, }, },
|
||||||
|
{ "Hello" , { 9906, }, },
|
||||||
|
{ " Hello" , { 22691, }, },
|
||||||
|
{ " Hello" , { 220, 22691, }, },
|
||||||
|
{ " Hello" , { 256, 22691, }, },
|
||||||
|
{ " Hello" , { 262, 22691, }, },
|
||||||
|
{ " Hello\n Hello" , { 262, 22691, 198, 262, 22691, }, },
|
||||||
|
{ " (" , { 320, }, },
|
||||||
|
{ "\n =" , { 198, 284, }, },
|
||||||
|
{ "' era" , { 6, 11639, }, },
|
||||||
|
{ "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~", { 9906, 11, 379, 65948, 0, 2650, 527, 499, 27623, 223, 949, 37046, 101067, 19000, 23182, 102301, 9263, 18136, 16, 36827, 21909, }, },
|
||||||
|
{ "3" , { 18, }, },
|
||||||
|
{ "33" , { 1644, }, },
|
||||||
|
{ "333" , { 8765, }, },
|
||||||
|
{ "3333" , { 8765, 18, }, },
|
||||||
|
{ "33333" , { 8765, 1644, }, },
|
||||||
|
{ "333333" , { 8765, 8765, }, },
|
||||||
|
{ "3333333" , { 8765, 8765, 18, }, },
|
||||||
|
{ "33333333" , { 8765, 8765, 1644, }, },
|
||||||
|
{ "333333333" , { 8765, 8765, 8765, }, },
|
||||||
|
};
|
||||||
|
|
||||||
|
return _k_tests;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
if (argc < 2) {
|
||||||
|
fprintf(stderr, "Usage: %s vocab-file [text-file]\n", argv[0]);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string fname = argv[1];
|
||||||
|
|
||||||
|
std::string fname_text;
|
||||||
|
if (argc > 2) {
|
||||||
|
fname_text = argv[2];
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
|
||||||
|
|
||||||
|
llama_model * model;
|
||||||
|
llama_context * ctx;
|
||||||
|
|
||||||
|
llama_backend_init();
|
||||||
|
|
||||||
|
// load the vocab
|
||||||
|
{
|
||||||
|
auto mparams = llama_model_default_params();
|
||||||
|
|
||||||
|
mparams.vocab_only = true;
|
||||||
|
|
||||||
|
model = llama_load_model_from_file(fname.c_str(), mparams);
|
||||||
|
|
||||||
|
if (model == NULL) {
|
||||||
|
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto cparams = llama_context_default_params();
|
||||||
|
|
||||||
|
ctx = llama_new_context_with_model(model, cparams);
|
||||||
|
|
||||||
|
if (ctx == NULL) {
|
||||||
|
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
|
||||||
|
llama_free_model(model);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) {
|
||||||
|
fprintf(stderr, "%s : error: vocab type is not BPE\n", __func__);
|
||||||
|
llama_free_model(model);
|
||||||
|
llama_free(ctx);
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
// We need this for unicode console support
|
||||||
|
console::init(false, false);
|
||||||
|
atexit([]() { console::cleanup(); });
|
||||||
|
#endif
|
||||||
|
|
||||||
|
bool success = true;
|
||||||
|
|
||||||
|
for (const auto & test_kv : k_tests()) {
|
||||||
|
const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, false);
|
||||||
|
|
||||||
|
printf("\n");
|
||||||
|
printf("src: '%s'\n", test_kv.first.c_str());
|
||||||
|
printf("res: '%s'\n", llama_detokenize_bpe(ctx, res).c_str());
|
||||||
|
printf("tok: ");
|
||||||
|
for (const auto & tok : res) {
|
||||||
|
printf("%d ", tok);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
|
||||||
|
bool correct = res.size() == test_kv.second.size();
|
||||||
|
for (int i = 0; i < (int) res.size() && correct; ++i) {
|
||||||
|
if (test_kv.second[i] != res[i]) {
|
||||||
|
correct = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!correct) {
|
||||||
|
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
|
||||||
|
fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
|
||||||
|
llama_detokenize_bpe(ctx, res).c_str(),
|
||||||
|
llama_detokenize_bpe(ctx, test_kv.second).c_str());
|
||||||
|
fprintf(stderr, "%s : expected tokens: ", __func__);
|
||||||
|
for (const auto & t : test_kv.second) {
|
||||||
|
fprintf(stderr, "%6d, ", t);
|
||||||
|
}
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
fprintf(stderr, "%s : got tokens: ", __func__);
|
||||||
|
for (const auto & t : res) {
|
||||||
|
fprintf(stderr, "%6d, ", t);
|
||||||
|
}
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
|
||||||
|
success = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!fname_text.empty()) {
|
||||||
|
fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());
|
||||||
|
|
||||||
|
std::string text;
|
||||||
|
{
|
||||||
|
std::ifstream ifs(fname_text);
|
||||||
|
if (!ifs) {
|
||||||
|
fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_text.c_str());
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
text = std::string(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf(stderr, "%s : text size: %zu\n", __func__, text.size());
|
||||||
|
|
||||||
|
const std::vector<llama_token> res = llama_tokenize(ctx, text, false);
|
||||||
|
|
||||||
|
fprintf(stderr, "%s : tokens: %zu\n", __func__, res.size());
|
||||||
|
|
||||||
|
{
|
||||||
|
const std::string fname_out = fname_text + ".tokcpp";
|
||||||
|
|
||||||
|
std::ofstream ofs(fname_out);
|
||||||
|
if (!ofs) {
|
||||||
|
fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_out.c_str());
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto & tok : res) {
|
||||||
|
ofs << tok << " '" << llama_detokenize_bpe(ctx, std::vector<int>{tok}) << "'" << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_free_model(model);
|
||||||
|
llama_free(ctx);
|
||||||
|
|
||||||
|
llama_backend_free();
|
||||||
|
|
||||||
|
return success ? 0 : 3;
|
||||||
|
}
|
|
@ -126,7 +126,7 @@ int main(int argc, char **argv) {
|
||||||
}
|
}
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
bool correct = res_nobos.size() == test_kv.second.size() && res_bos.size() == res_nobos.size() + 1 && res_bos[0] == 1;
|
bool correct = res_nobos.size() == test_kv.second.size() && res_bos.size() == res_nobos.size() + 1 && res_bos[0] == llama_token_bos(model);
|
||||||
|
|
||||||
for (int i = 0; i < (int) res_nobos.size() && correct; ++i) {
|
for (int i = 0; i < (int) res_nobos.size() && correct; ++i) {
|
||||||
if (test_kv.second[i] != res_bos[i + 1]) {
|
if (test_kv.second[i] != res_bos[i + 1]) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue