merged the changes from deepseeker models to main branch
This commit is contained in:
parent
83b72cb086
commit
6fbab2dbc8
15 changed files with 886 additions and 151 deletions
17
.gitignore
vendored
17
.gitignore
vendored
|
@ -108,3 +108,20 @@ examples/server/*.mjs.hpp
|
|||
poetry.lock
|
||||
poetry.toml
|
||||
nppBackup
|
||||
|
||||
# Test binaries
|
||||
/tests/test-grammar-parser
|
||||
/tests/test-llama-grammar
|
||||
/tests/test-double-float
|
||||
/tests/test-grad0
|
||||
/tests/test-opt
|
||||
/tests/test-quantize-fns
|
||||
/tests/test-quantize-perf
|
||||
/tests/test-sampling
|
||||
/tests/test-tokenizer-0-llama
|
||||
/tests/test-tokenizer-0-falcon
|
||||
/tests/test-tokenizer-0-deepseek-coder
|
||||
/tests/test-tokenizer-1-llama
|
||||
/tests/test-tokenizer-1-bpe
|
||||
/tests/test-rope
|
||||
/tests/test-backend-ops
|
||||
|
|
13
Makefile
13
Makefile
|
@ -6,7 +6,8 @@ BUILD_TARGETS = \
|
|||
|
||||
# Binaries only useful for tests
|
||||
TEST_TARGETS = \
|
||||
tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt \
|
||||
tests/test-llama-grammar tests/test-tokenizer-0-deepseek-coder tests/test-tokenizer-0-deepseek-llm \
|
||||
tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt \
|
||||
tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama \
|
||||
tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe tests/test-rope \
|
||||
tests/test-backend-ops tests/test-model-load-cancel tests/test-autorelease \
|
||||
|
@ -53,6 +54,10 @@ test: $(TEST_TARGETS)
|
|||
./$$test_target $(CURDIR)/models/ggml-vocab-llama.gguf; \
|
||||
elif [ "$$test_target" = "tests/test-tokenizer-0-falcon" ]; then \
|
||||
./$$test_target $(CURDIR)/models/ggml-vocab-falcon.gguf; \
|
||||
elif [ "$$test_target" = "tests/test-tokenizer-0-deepseek-coder" ]; then \
|
||||
./$$test_target $(CURDIR)/models/ggml-vocab-deepseek-coder.gguf; \
|
||||
elif [ "$$test_target" = "tests/test-tokenizer-0-deepseek-llm" ]; then \
|
||||
./$$test_target $(CURDIR)/models/ggml-vocab-deepseek-llm.gguf; \
|
||||
elif [ "$$test_target" = "tests/test-tokenizer-1-llama" ]; then \
|
||||
continue; \
|
||||
elif [ "$$test_target" = "tests/test-tokenizer-1-bpe" ]; then \
|
||||
|
@ -979,6 +984,12 @@ tests/test-tokenizer-0-llama: tests/test-tokenizer-0-llama.cpp ggml.o llama.o $(
|
|||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
tests/test-tokenizer-0-deepseek-coder: tests/test-tokenizer-0-deepseek-coder.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||
|
||||
tests/test-tokenizer-0-deepseek-llm: tests/test-tokenizer-0-deepseek-llm.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||
|
||||
tests/test-tokenizer-1-bpe: tests/test-tokenizer-1-bpe.cpp ggml.o llama.o $(COMMON_DEPS) console.o $(OBJS)
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
|
|
@ -215,6 +215,78 @@ class Model(ABC):
|
|||
except KeyError:
|
||||
raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
|
||||
|
||||
@staticmethod
|
||||
def from_model_architecture(model_architecture):
|
||||
if model_architecture == "GPTNeoXForCausalLM":
|
||||
return GPTNeoXModel
|
||||
if model_architecture == "BloomForCausalLM":
|
||||
return BloomModel
|
||||
if model_architecture == "MPTForCausalLM":
|
||||
return MPTModel
|
||||
if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
|
||||
return BaichuanModel
|
||||
if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
|
||||
return FalconModel
|
||||
if model_architecture == "GPTBigCodeForCausalLM":
|
||||
return StarCoderModel
|
||||
if model_architecture == "GPTRefactForCausalLM":
|
||||
return RefactModel
|
||||
if model_architecture == "PersimmonForCausalLM":
|
||||
return PersimmonModel
|
||||
if model_architecture == "LlamaForCausalLM":
|
||||
return DeepseekCoderModel
|
||||
if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
|
||||
return StableLMModel
|
||||
if model_architecture == "QWenLMHeadModel":
|
||||
return QwenModel
|
||||
if model_architecture == "Qwen2ForCausalLM":
|
||||
return Model
|
||||
if model_architecture == "MixtralForCausalLM":
|
||||
return MixtralModel
|
||||
if model_architecture == "GPT2LMHeadModel":
|
||||
return GPT2Model
|
||||
if model_architecture == "PhiForCausalLM":
|
||||
return Phi2Model
|
||||
if model_architecture == "PlamoForCausalLM":
|
||||
return PlamoModel
|
||||
if model_architecture == "CodeShellForCausalLM":
|
||||
return CodeShellModel
|
||||
if model_architecture == "OrionForCausalLM":
|
||||
return OrionModel
|
||||
if model_architecture == "InternLM2ForCausalLM":
|
||||
return InternLM2Model
|
||||
if model_architecture == "MiniCPMForCausalLM":
|
||||
return MiniCPMModel
|
||||
if model_architecture == "BertModel":
|
||||
return BertModel
|
||||
|
||||
@staticmethod
|
||||
def from_model_name(model_name: str):
|
||||
model_name_lower = model_name.lower()
|
||||
if model_name_lower in ("stablelmepoch", "llavastablelmepoch"):
|
||||
return StableLMModel
|
||||
if model_name_lower == "gptneox":
|
||||
return GPTNeoXModel
|
||||
if model_name_lower == "bloom":
|
||||
return BloomModel
|
||||
if model_name_lower == "mpt":
|
||||
return MPTModel
|
||||
if model_name_lower in ("baichuan"):
|
||||
return BaichuanModel
|
||||
if model_name_lower in ("falcon", "rw"):
|
||||
return FalconModel
|
||||
if model_name_lower == "gptbigcode":
|
||||
return StarCoderModel
|
||||
if model_name_lower == "gptrefact":
|
||||
return RefactModel
|
||||
if model_name_lower == "persimmon":
|
||||
return PersimmonModel
|
||||
if model_name_lower == "deepseekcoder":
|
||||
return DeepseekCoderModel
|
||||
if model_name_lower == "deepseekllm":
|
||||
return DeepseekLLMModel
|
||||
return Model
|
||||
|
||||
def _is_model_safetensors(self) -> bool:
|
||||
return Model.count_model_parts(self.dir_model, ".safetensors") > 0
|
||||
|
||||
|
@ -228,6 +300,53 @@ class Model(ABC):
|
|||
return ("pytorch_model.bin",)
|
||||
return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
|
||||
|
||||
def _get_model_architecture(self) -> gguf.MODEL_ARCH:
|
||||
arch = self.hparams["architectures"][0]
|
||||
if arch == "GPTNeoXForCausalLM":
|
||||
return gguf.MODEL_ARCH.GPTNEOX
|
||||
if arch == "BloomForCausalLM":
|
||||
return gguf.MODEL_ARCH.BLOOM
|
||||
if arch == "MPTForCausalLM":
|
||||
return gguf.MODEL_ARCH.MPT
|
||||
if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
|
||||
return gguf.MODEL_ARCH.BAICHUAN
|
||||
if arch in ("FalconForCausalLM", "RWForCausalLM"):
|
||||
return gguf.MODEL_ARCH.FALCON
|
||||
if arch == "GPTBigCodeForCausalLM":
|
||||
return gguf.MODEL_ARCH.STARCODER
|
||||
if arch == "GPTRefactForCausalLM":
|
||||
return gguf.MODEL_ARCH.REFACT
|
||||
if arch == "PersimmonForCausalLM":
|
||||
return gguf.MODEL_ARCH.PERSIMMON
|
||||
if arch == "LlamaForCausalLM":
|
||||
return gguf.MODEL_ARCH.LLAMA
|
||||
if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
|
||||
return gguf.MODEL_ARCH.STABLELM
|
||||
if arch == "QWenLMHeadModel":
|
||||
return gguf.MODEL_ARCH.QWEN
|
||||
if arch == "Qwen2ForCausalLM":
|
||||
return gguf.MODEL_ARCH.QWEN2
|
||||
if arch == "MixtralForCausalLM":
|
||||
return gguf.MODEL_ARCH.LLAMA
|
||||
if arch == "GPT2LMHeadModel":
|
||||
return gguf.MODEL_ARCH.GPT2
|
||||
if arch == "PhiForCausalLM":
|
||||
return gguf.MODEL_ARCH.PHI2
|
||||
if arch == "PlamoForCausalLM":
|
||||
return gguf.MODEL_ARCH.PLAMO
|
||||
if arch == "CodeShellForCausalLM":
|
||||
return gguf.MODEL_ARCH.CODESHELL
|
||||
if arch == "OrionForCausalLM":
|
||||
return gguf.MODEL_ARCH.ORION
|
||||
if arch == "InternLM2ForCausalLM":
|
||||
return gguf.MODEL_ARCH.INTERNLM2
|
||||
if arch == "MiniCPMForCausalLM":
|
||||
return gguf.MODEL_ARCH.MINICPM
|
||||
if arch == "BertModel":
|
||||
return gguf.MODEL_ARCH.BERT
|
||||
|
||||
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
||||
|
||||
# used for GPT-2 BPE and WordPiece vocabs
|
||||
def get_basic_vocab(self) -> tuple[list[str], list[int]]:
|
||||
tokens: list[str] = []
|
||||
|
@ -257,9 +376,10 @@ class Model(ABC):
|
|||
|
||||
return tokens, toktypes
|
||||
|
||||
def _set_vocab_gpt2(self) -> None:
|
||||
|
||||
def _set_vocab_gpt2(self, tokenizer_model:str = "gpt2") -> None:
|
||||
tokens, toktypes = self.get_basic_vocab()
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
self.gguf_writer.add_tokenizer_model(tokenizer_model)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
|
@ -1192,7 +1312,29 @@ class PersimmonModel(Model):
|
|||
n_dims = len(data.shape)
|
||||
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
@Model.register("LlamaForCausalLM")
|
||||
class DeepseekCoderModel(Model):
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
head_count = self.hparams["num_attention_heads"]
|
||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||
self.gguf_writer.add_head_count(head_count)
|
||||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(head_count_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
|
||||
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
|
||||
|
||||
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
||||
if self.hparams["rope_scaling"].get("type") == "linear":
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_gpt2("deepseek_coder")
|
||||
|
||||
class DeepseekLLMModel(DeepseekCoderModel):
|
||||
def set_vocab(self):
|
||||
self._set_vocab_gpt2("deepseek_llm")
|
||||
|
||||
@Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM")
|
||||
class StableLMModel(Model):
|
||||
|
@ -2843,6 +2985,7 @@ def parse_args() -> argparse.Namespace:
|
|||
help="directory containing model file",
|
||||
)
|
||||
parser.add_argument("--use-temp-file", action="store_true", help="use the tempfile library while processing (helpful when running out of memory, process killed)")
|
||||
parser.add_argument("--model-name", type=str, default=None, help="name of the model")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
|
233
llama.cpp
233
llama.cpp
|
@ -4228,9 +4228,19 @@ static void llm_load_vocab(
|
|||
if (add_space_prefix_keyidx != -1) {
|
||||
vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
|
||||
} // The default value of add_space_prefix is true.
|
||||
} else if (tokenizer_name == "gpt2") {
|
||||
} else {
|
||||
if (tokenizer_name == "gpt2") {
|
||||
vocab.type = LLAMA_VOCAB_TYPE_BPE;
|
||||
|
||||
} else if (tokenizer_name == "deepseek_coder") {
|
||||
vocab.type = LLAMA_VOCAB_TYPE_DEEPSEEKCODER;
|
||||
} else if (tokenizer_name == "deepseek_llm") {
|
||||
vocab.type = LLAMA_VOCAB_TYPE_DEEPSEEKLLM;
|
||||
} else {
|
||||
LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
|
||||
LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
|
||||
vocab.type = LLAMA_VOCAB_TYPE_SPM;
|
||||
return;
|
||||
}
|
||||
// read bpe merges and populate bpe ranks
|
||||
const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
|
||||
if (merges_keyidx == -1) {
|
||||
|
@ -11779,6 +11789,7 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
|
|||
auto buf = token_data.text.substr(3, 2);
|
||||
return strtol(buf.c_str(), NULL, 16);
|
||||
}
|
||||
case LLAMA_VOCAB_TYPE_DEEPSEEKCODER:
|
||||
case LLAMA_VOCAB_TYPE_BPE: {
|
||||
GGML_ASSERT(false);
|
||||
return unicode_utf8_to_byte(token_data.text);
|
||||
|
@ -11806,6 +11817,7 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
|
|||
return vocab.token_to_id.at(buf2);
|
||||
}
|
||||
case LLAMA_VOCAB_TYPE_WPM:
|
||||
case LLAMA_VOCAB_TYPE_DEEPSEEKCODER:
|
||||
case LLAMA_VOCAB_TYPE_BPE: {
|
||||
return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
|
||||
}
|
||||
|
@ -12003,7 +12015,21 @@ struct llm_tokenizer_bpe {
|
|||
|
||||
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||
int final_prev_index = -1;
|
||||
auto word_collection = bpe_gpt2_preprocess(text);
|
||||
|
||||
std::vector<std::string> word_collection;
|
||||
switch (vocab.type) {
|
||||
case LLAMA_VOCAB_TYPE_BPE:
|
||||
word_collection = bpe_gpt2_preprocess(text);
|
||||
break;
|
||||
case LLAMA_VOCAB_TYPE_DEEPSEEKCODER:
|
||||
word_collection = bpe_deepseek_coder_preprocess(text);
|
||||
break;
|
||||
case LLAMA_VOCAB_TYPE_DEEPSEEKLLM:
|
||||
word_collection = bpe_deepseek_llm_preprocess(text);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
symbols_final.clear();
|
||||
|
||||
|
@ -12130,145 +12156,84 @@ private:
|
|||
work_queue.push(bigram);
|
||||
}
|
||||
|
||||
std::vector<std::string> bpe_gpt2_preprocess(const std::string & text) {
|
||||
std::vector<std::string> bpe_words;
|
||||
std::vector<std::string> byte_encoding_process(const std::vector<std::string> & bpe_words) {
|
||||
std::vector<std::string>bpe_encoded_words;
|
||||
for (auto word : bpe_words) {
|
||||
std::string text_utf = "";
|
||||
auto utf_word = unicode_cpts_from_utf8(word);
|
||||
for (size_t i = 0; i < utf_word.size(); ++i)
|
||||
text_utf += unicode_cpt_to_utf8(utf_word[i]);
|
||||
|
||||
std::string token = "";
|
||||
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
|
||||
bool collecting_numeric = false;
|
||||
bool collecting_letter = false;
|
||||
bool collecting_special = false;
|
||||
bool collecting_whitespace_lookahead = false;
|
||||
bool collecting = false;
|
||||
|
||||
std::vector<std::string> text_utf;
|
||||
text_utf.reserve(text.size());
|
||||
bpe_words.reserve(text.size());
|
||||
bpe_encoded_words.reserve(text.size());
|
||||
|
||||
const auto cpts = unicode_cpts_from_utf8(text);
|
||||
for (size_t i = 0; i < cpts.size(); ++i)
|
||||
text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i]));
|
||||
|
||||
for (int i = 0; i < (int)text_utf.size(); i++) {
|
||||
const std::string & utf_char = text_utf[i];
|
||||
bool split_condition = false;
|
||||
int bytes_remain = text_utf.size() - i;
|
||||
// forward backward lookups
|
||||
const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
|
||||
const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";
|
||||
|
||||
// handling contractions
|
||||
if (!split_condition && bytes_remain >= 2) {
|
||||
// 's|'t|'m|'d
|
||||
if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) {
|
||||
split_condition = true;
|
||||
}
|
||||
if (split_condition) {
|
||||
if (token.size()) {
|
||||
bpe_words.emplace_back(token); // push previous content as token
|
||||
}
|
||||
token = utf_char + utf_char_next;
|
||||
bpe_words.emplace_back(token);
|
||||
token = "";
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (!split_condition && bytes_remain >= 3) {
|
||||
// 're|'ve|'ll
|
||||
if (utf_char == "\'" && (
|
||||
(utf_char_next == "r" && utf_char_next_next == "e") ||
|
||||
(utf_char_next == "v" && utf_char_next_next == "e") ||
|
||||
(utf_char_next == "l" && utf_char_next_next == "l"))
|
||||
) {
|
||||
split_condition = true;
|
||||
}
|
||||
if (split_condition) {
|
||||
// current token + next token can be defined
|
||||
if (token.size()) {
|
||||
bpe_words.emplace_back(token); // push previous content as token
|
||||
}
|
||||
token = utf_char + utf_char_next + utf_char_next_next;
|
||||
bpe_words.emplace_back(token); // the contraction
|
||||
token = "";
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (!split_condition && !collecting) {
|
||||
if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
|
||||
collecting_letter = true;
|
||||
collecting = true;
|
||||
}
|
||||
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
|
||||
collecting_numeric = true;
|
||||
collecting = true;
|
||||
}
|
||||
else if (
|
||||
((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
|
||||
(!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
|
||||
) {
|
||||
collecting_special = true;
|
||||
collecting = true;
|
||||
}
|
||||
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
|
||||
collecting_whitespace_lookahead = true;
|
||||
collecting = true;
|
||||
}
|
||||
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
|
||||
split_condition = true;
|
||||
}
|
||||
}
|
||||
else if (!split_condition && collecting) {
|
||||
if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) {
|
||||
split_condition = true;
|
||||
}
|
||||
else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
|
||||
split_condition = true;
|
||||
}
|
||||
else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
|
||||
split_condition = true;
|
||||
}
|
||||
else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
|
||||
split_condition = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (utf_char_next == "") {
|
||||
split_condition = true; // final
|
||||
token += utf_char;
|
||||
}
|
||||
|
||||
if (split_condition) {
|
||||
if (token.size()) {
|
||||
bpe_words.emplace_back(token);
|
||||
}
|
||||
token = utf_char;
|
||||
collecting = false;
|
||||
collecting_letter = false;
|
||||
collecting_numeric = false;
|
||||
collecting_special = false;
|
||||
collecting_whitespace_lookahead = false;
|
||||
}
|
||||
else {
|
||||
token += utf_char;
|
||||
}
|
||||
}
|
||||
|
||||
for (std::string & word : bpe_words) {
|
||||
std::string encoded_token = "";
|
||||
for (char & c : word) {
|
||||
for (char & c : text_utf) {
|
||||
encoded_token += unicode_byte_to_utf8(c);
|
||||
}
|
||||
bpe_encoded_words.emplace_back(encoded_token);
|
||||
}
|
||||
|
||||
return bpe_encoded_words;
|
||||
}
|
||||
|
||||
std::vector<size_t> regex_preprocess(const std::wstring & text, const std::vector<size_t> & offsets, const std::wstring & regex_expr) {
|
||||
std::wregex expr(regex_expr);
|
||||
std::vector<size_t> bpe_words; // stroe the offset of each word
|
||||
bpe_words.reserve(offsets.size()); // Reserve memory for the approximate size
|
||||
size_t start = 0;
|
||||
for (auto offset : offsets) {
|
||||
std::wcregex_iterator it(text.data() + start, text.data() + start + offset, expr);
|
||||
std::wcregex_iterator end;
|
||||
|
||||
int64_t start_idx = 0;
|
||||
while (it != end) {
|
||||
std::wcmatch match = *it;
|
||||
if (match.position() > start_idx) {
|
||||
bpe_words.emplace_back(match.position() - start_idx);
|
||||
}
|
||||
bpe_words.emplace_back(match.length());
|
||||
start_idx = match.position() + match.length();
|
||||
++it;
|
||||
}
|
||||
|
||||
if (start_idx < (int64_t) offset) {
|
||||
bpe_words.emplace_back(offset - start_idx);
|
||||
}
|
||||
start += offset;
|
||||
}
|
||||
|
||||
return bpe_words;
|
||||
}
|
||||
|
||||
std::vector<std::string> regex_bpe_preprocess(const std::string & text, const std::vector<std::wstring> & regex_exprs) {
|
||||
std::wstring wtext = from_utf8(text);
|
||||
|
||||
std::vector<size_t> bpe_offsets = {wtext.size()};
|
||||
|
||||
for(auto & regex_expr : regex_exprs) {
|
||||
bpe_offsets = regex_preprocess(wtext, bpe_offsets, regex_expr);
|
||||
}
|
||||
|
||||
std::vector<std::string> bpe_words;
|
||||
bpe_words.reserve(bpe_offsets.size()); // Reserve memory for the approximate size
|
||||
size_t start = 0;
|
||||
for(size_t & offset : bpe_offsets){
|
||||
bpe_words.emplace_back(to_utf8(std::wstring(wtext, start, offset)));
|
||||
start += offset;
|
||||
}
|
||||
|
||||
return byte_encoding_process(bpe_words);
|
||||
}
|
||||
|
||||
std::vector<std::string> bpe_gpt2_preprocess(const std::string & text) {
|
||||
return regex_bpe_preprocess(text, gpt2_regex);
|
||||
}
|
||||
|
||||
std::vector<std::string> bpe_deepseek_coder_preprocess(const std::string & text) {
|
||||
return regex_bpe_preprocess(text, deepseek_coder_regex);
|
||||
}
|
||||
|
||||
std::vector<std::string> bpe_deepseek_llm_preprocess(const std::string & text) {
|
||||
return regex_bpe_preprocess(text, deepseek_llm_regex);
|
||||
}
|
||||
|
||||
const llama_vocab & vocab;
|
||||
|
||||
std::vector<llm_symbol> symbols;
|
||||
|
@ -12586,6 +12551,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
|||
output.push_back(vocab.special_eos_id);
|
||||
}
|
||||
} break;
|
||||
case LLAMA_VOCAB_TYPE_DEEPSEEKCODER:
|
||||
case LLAMA_VOCAB_TYPE_DEEPSEEKLLM:
|
||||
case LLAMA_VOCAB_TYPE_BPE:
|
||||
{
|
||||
if (add_special && vocab.special_add_bos == 1) {
|
||||
|
@ -17286,6 +17253,8 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
|
|||
}
|
||||
break;
|
||||
}
|
||||
case LLAMA_VOCAB_TYPE_DEEPSEEKCODER:
|
||||
case LLAMA_VOCAB_TYPE_DEEPSEEKLLM:
|
||||
case LLAMA_VOCAB_TYPE_BPE: {
|
||||
// NOTE: we accept all unsupported token types,
|
||||
// suppressing them like CONTROL tokens.
|
||||
|
|
2
llama.h
2
llama.h
|
@ -67,6 +67,8 @@ extern "C" {
|
|||
LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
|
||||
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
|
||||
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
||||
LLAMA_VOCAB_TYPE_DEEPSEEKCODER = 4, // Deepseek Coder
|
||||
LLAMA_VOCAB_TYPE_DEEPSEEKLLM = 5, // Deepseek LLM
|
||||
};
|
||||
|
||||
// note: these values should be synchronized with ggml_rope
|
||||
|
|
BIN
models/ggml-vocab-deepseek-coder.gguf
Normal file
BIN
models/ggml-vocab-deepseek-coder.gguf
Normal file
Binary file not shown.
BIN
models/ggml-vocab-deepseek-llm.gguf
Normal file
BIN
models/ggml-vocab-deepseek-llm.gguf
Normal file
Binary file not shown.
|
@ -41,9 +41,12 @@ llama_test(test-quantize-perf.cpp)
|
|||
llama_test(test-sampling.cpp)
|
||||
llama_test(test-chat-template.cpp)
|
||||
|
||||
|
||||
llama_test(test-tokenizer-0-llama.cpp NAME test-tokenizer-0-llama ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
|
||||
llama_test(test-tokenizer-0-falcon.cpp NAME test-tokenizer-0-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
|
||||
|
||||
llama_test(test-tokenizer-0-deepseek-coder.cpp NAME test-tokenizer-0-deepseek-coder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-coder.gguf)
|
||||
|
||||
llama_test(test-tokenizer-1-llama.cpp NAME test-tokenizer-1-llama ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
|
||||
llama_test(test-tokenizer-1-llama.cpp NAME test-tokenizer-1-baichuan ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
|
||||
|
||||
|
|
188
tests/test-tokenizer-0-deepseek-coder.cpp
Normal file
188
tests/test-tokenizer-0-deepseek-coder.cpp
Normal file
|
@ -0,0 +1,188 @@
|
|||
#include "llama.h"
|
||||
#include "common.h"
|
||||
#include "console.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
|
||||
// generate using test-tokenizer-0-falcon.py
|
||||
static const std::map<std::string, std::vector<llama_token>> & k_tests() {
|
||||
static std::map<std::string, std::vector<llama_token>> _k_tests = {
|
||||
{ "" , { }, },
|
||||
{ " " , { 207, }, },
|
||||
{ " " , { 243, }, },
|
||||
{ " " , { 315, }, },
|
||||
{ "\t" , { 184, }, },
|
||||
{ "\n" , { 185, }, },
|
||||
{ "\t\n" , { 184, 185, }, },
|
||||
{ "Hello world" , { 17535, 1835, }, },
|
||||
{ " Hello world" , { 414, 9489, 1835, }, },
|
||||
{ "Hello World" , { 17535, 5414, }, },
|
||||
{ " Hello World" , { 414, 9489, 5414, }, },
|
||||
{ " Hello World!" , { 414, 9489, 5414, 0, }, },
|
||||
{ "Hello, world!" , { 17535, 11, 1835, 0, }, },
|
||||
{ " Hello, world!" , { 414, 9489, 11, 1835, 0, }, },
|
||||
{ " this is 🦙.cpp" , { 437, 317, 12394, 99, 234, 13, 14789, }, },
|
||||
{ "w048 7tuijk dsdfhu" , { 86, 15, 19, 23, 207, 22, 83, 3963, 27659, 26078, 3934, 14072, }, },
|
||||
{ "нещо на Български" , { 1593, 6478, 616, 2251, 14994, }, },
|
||||
{ "កាន់តែពិសេសអាចខលចេញ" , { 155, 239, 209, 155, 239, 114, 155, 239, 228, 155, 240, 220, 155, 239, 224, 155, 240, 211, 155, 239, 231, 155, 239, 115, 155, 239, 240, 155, 240, 210, 155, 239, 240, 155, 239, 95, 155, 239, 114, 155, 239, 214, 155, 239, 210, 155, 239, 236, 155, 239, 214, 155, 240, 210, 155, 239, 218, }, },
|
||||
{ "🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", { 10047, 235, 209, 334, 8760, 8, 12394, 233, 114, 350, 222, 10047, 221, 104, 169, 116, 224, 334, 4684, 3909, 992, 24330, 262, 29651, 612, 8, 207, 156, 237, 214, 334, 5950, 992, 78, 12896, 344, 638, 891, 1372, 10736, 8, }, },
|
||||
{ "Hello" , { 17535, }, },
|
||||
{ " Hello" , { 414, 9489, }, },
|
||||
{ " Hello" , { 207, 414, 9489, }, },
|
||||
{ " Hello" , { 243, 414, 9489, }, },
|
||||
{ " Hello" , { 315, 414, 9489, }, },
|
||||
{ " Hello\n Hello" , { 315, 414, 9489, 185, 315, 414, 9489, }, },
|
||||
{ "\n =" , { 185, 405, }, },
|
||||
{ "' era" , { 6, 2895, }, },
|
||||
{ "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~", { 17535, 11, 320, 6, 435, 0, 1717, 417, 340, 12394, 233, 210, 3015, 19100, 608, 9413, 2668, 16, 18, 16, 19, 16, 20, 16, 1393, 169, 121, 239, }, },
|
||||
|
||||
};
|
||||
|
||||
return _k_tests;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
if (argc < 2) {
|
||||
fprintf(stderr, "Usage: %s vocab-file [text-file]\n", argv[0]);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const std::string fname = argv[1];
|
||||
|
||||
std::string fname_text;
|
||||
if (argc > 2) {
|
||||
fname_text = argv[2];
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
|
||||
|
||||
llama_model * model;
|
||||
llama_context * ctx;
|
||||
|
||||
llama_backend_init(false);
|
||||
|
||||
// load the vocab
|
||||
{
|
||||
auto mparams = llama_model_default_params();
|
||||
|
||||
mparams.vocab_only = true;
|
||||
|
||||
model = llama_load_model_from_file(fname.c_str(), mparams);
|
||||
|
||||
if (model == NULL) {
|
||||
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto cparams = llama_context_default_params();
|
||||
|
||||
ctx = llama_new_context_with_model(model, cparams);
|
||||
|
||||
if (ctx == NULL) {
|
||||
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
|
||||
llama_free_model(model);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_DEEPSEEKCODER) {
|
||||
fprintf(stderr, "%s : error: vocab type is not DEEPSEEKCODER\n", __func__);
|
||||
llama_free_model(model);
|
||||
llama_free(ctx);
|
||||
return 2;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
// We need this for unicode console support
|
||||
console::init(false, false);
|
||||
atexit([]() { console::cleanup(); });
|
||||
#endif
|
||||
|
||||
bool success = true;
|
||||
|
||||
for (const auto & test_kv : k_tests()) {
|
||||
const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, false);
|
||||
|
||||
printf("\n");
|
||||
printf("src: '%s'\n", test_kv.first.c_str());
|
||||
printf("res: '%s'\n", llama_detokenize_bpe(ctx, res).c_str());
|
||||
printf("tok: ");
|
||||
for (const auto & tok : res) {
|
||||
printf("%d ", tok);
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
bool correct = res.size() == test_kv.second.size();
|
||||
for (int i = 0; i < (int) res.size() && correct; ++i) {
|
||||
if (test_kv.second[i] != res[i]) {
|
||||
correct = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!correct) {
|
||||
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
|
||||
fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
|
||||
llama_detokenize_bpe(ctx, res).c_str(),
|
||||
llama_detokenize_bpe(ctx, test_kv.second).c_str());
|
||||
fprintf(stderr, "%s : expected tokens: ", __func__);
|
||||
for (const auto & t : test_kv.second) {
|
||||
fprintf(stderr, "%6d, ", t);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s : got tokens: ", __func__);
|
||||
for (const auto & t : res) {
|
||||
fprintf(stderr, "%6d, ", t);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
success = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!fname_text.empty()) {
|
||||
fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());
|
||||
|
||||
std::string text;
|
||||
{
|
||||
std::ifstream ifs(fname_text);
|
||||
if (!ifs) {
|
||||
fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_text.c_str());
|
||||
return 1;
|
||||
}
|
||||
text = std::string(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : text size: %zu\n", __func__, text.size());
|
||||
|
||||
const std::vector<llama_token> res = llama_tokenize(ctx, text, false);
|
||||
|
||||
fprintf(stderr, "%s : tokens: %zu\n", __func__, res.size());
|
||||
|
||||
{
|
||||
const std::string fname_out = fname_text + ".tokcpp";
|
||||
|
||||
std::ofstream ofs(fname_out);
|
||||
if (!ofs) {
|
||||
fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_out.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (const auto & tok : res) {
|
||||
ofs << tok << " '" << llama_detokenize_bpe(ctx, std::vector<int>{tok}) << "'" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str());
|
||||
}
|
||||
|
||||
llama_free_model(model);
|
||||
llama_free(ctx);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
return success ? 0 : 3;
|
||||
}
|
83
tests/test-tokenizer-0-deepseek-coder.py
Normal file
83
tests/test-tokenizer-0-deepseek-coder.py
Normal file
|
@ -0,0 +1,83 @@
|
|||
# tests with BPE tokenizer
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
|
||||
parser.add_argument("--fname-tok", help="path to a text file to tokenize")
|
||||
args = parser.parse_args()
|
||||
|
||||
dir_tokenizer = args.dir_tokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_tokenizer)
|
||||
|
||||
tests = [
|
||||
"",
|
||||
" ",
|
||||
" ",
|
||||
" ",
|
||||
"\t",
|
||||
"\n",
|
||||
"\t\n",
|
||||
"Hello world",
|
||||
" Hello world",
|
||||
"Hello World",
|
||||
" Hello World",
|
||||
" Hello World!",
|
||||
"Hello, world!",
|
||||
" Hello, world!",
|
||||
" this is 🦙.cpp",
|
||||
"w048 7tuijk dsdfhu",
|
||||
"нещо на Български",
|
||||
"កាន់តែពិសេសអាចខលចេញ",
|
||||
"🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)",
|
||||
"Hello",
|
||||
" Hello",
|
||||
" Hello",
|
||||
" Hello",
|
||||
" Hello",
|
||||
" Hello\n Hello",
|
||||
"\n =",
|
||||
"' era",
|
||||
"Hello, y'all! How are you 😁 ?我想在apple工作1314151天~",
|
||||
]
|
||||
|
||||
for text in tests:
|
||||
print('text: ', text)
|
||||
print(tokenizer.encode(text))
|
||||
print(tokenizer.decode(tokenizer.encode(text)))
|
||||
|
||||
print("\n\ntests for C++:\n")
|
||||
for text in tests:
|
||||
res = tokenizer.encode(text)
|
||||
|
||||
k = text.replace('\n', '\\n')
|
||||
k = k.replace('\t', '\\t')
|
||||
k = '"' + k + '"'
|
||||
print("{ %-24s, { " % k, end='')
|
||||
for x in res:
|
||||
print("%7d," % x, end='')
|
||||
print(" }, },")
|
||||
|
||||
print(tokenizer.encode('hello'))
|
||||
print(tokenizer.encode('world'))
|
||||
print(tokenizer.encode(' world'))
|
||||
print(tokenizer.encode('hello world'))
|
||||
|
||||
fname_tok = args.fname_tok
|
||||
if fname_tok:
|
||||
print('tokenizing file: ', fname_tok)
|
||||
fname_out = fname_tok + '.tok'
|
||||
with open(fname_tok, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
s = ''.join(lines)
|
||||
res = tokenizer.encode(s)
|
||||
# write to file
|
||||
with open(fname_out, 'w', encoding='utf-8') as f:
|
||||
for x in res:
|
||||
f.write(str(x) + ' \'' + tokenizer.decode(x) + '\'\n')
|
||||
print('len(res): ', len(res))
|
||||
print('len(lines): ', len(lines))
|
||||
print('results written to: ', fname_out)
|
188
tests/test-tokenizer-0-deepseek-llm.cpp
Normal file
188
tests/test-tokenizer-0-deepseek-llm.cpp
Normal file
|
@ -0,0 +1,188 @@
|
|||
#include "llama.h"
|
||||
#include "common.h"
|
||||
#include "console.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
|
||||
// generate using test-tokenizer-0-falcon.py
|
||||
static const std::map<std::string, std::vector<llama_token>> & k_tests() {
|
||||
static std::map<std::string, std::vector<llama_token>> _k_tests = {
|
||||
{ "" , { }, },
|
||||
{ " " , { 207, }, },
|
||||
{ " " , { 243, }, },
|
||||
{ " " , { 300, }, },
|
||||
{ "\t" , { 184, }, },
|
||||
{ "\n" , { 185, }, },
|
||||
{ "\t\n" , { 184, 185, }, },
|
||||
{ "Hello world" , { 17464, 1843, }, },
|
||||
{ " Hello world" , { 37727, 1843, }, },
|
||||
{ "Hello World" , { 17464, 5427, }, },
|
||||
{ " Hello World" , { 37727, 5427, }, },
|
||||
{ " Hello World!" , { 37727, 5427, 0, }, },
|
||||
{ "Hello, world!" , { 17464, 11, 1843, 0, }, },
|
||||
{ " Hello, world!" , { 37727, 11, 1843, 0, }, },
|
||||
{ " this is 🦙.cpp" , { 437, 317, 12356, 99, 234, 13, 14743, }, },
|
||||
{ "w048 7tuijk dsdfhu" , { 86, 15, 19, 23, 207, 22, 83, 3970, 27519, 26016, 3944, 14025, }, },
|
||||
{ "нещо на Български" , { 1603, 6476, 620, 91754, }, },
|
||||
{ "កាន់តែពិសេសអាចខលចេញ" , { 71374, 209, 71374, 114, 71374, 228, 155, 240, 220, 71374, 224, 155, 240, 211, 71374, 231, 71374, 115, 71374, 240, 155, 240, 210, 71374, 240, 71374, 95, 71374, 114, 71374, 214, 71374, 210, 71374, 236, 71374, 214, 155, 240, 210, 71374, 218, }, },
|
||||
{ "🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", { 10044, 95300, 334, 8754, 8, 33701, 114, 350, 222, 10044, 221, 104, 46713, 334, 34732, 996, 24250, 262, 80923, 8, 207, 37103, 214, 334, 5956, 89213, 344, 643, 895, 1377, 10728, 8, }, },
|
||||
{ "Hello" , { 17464, }, },
|
||||
{ " Hello" , { 37727, }, },
|
||||
{ " Hello" , { 207, 37727, }, },
|
||||
{ " Hello" , { 243, 37727, }, },
|
||||
{ " Hello" , { 300, 37727, }, },
|
||||
{ " Hello\n Hello" , { 300, 37727, 185, 300, 37727, }, },
|
||||
{ "\n =" , { 185, 403, }, },
|
||||
{ "' era" , { 6, 2906, }, },
|
||||
{ "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~", { 17464, 11, 320, 6, 436, 0, 1724, 418, 340, 33701, 210, 3025, 19017, 612, 9407, 2681, 16, 18, 16, 19, 16, 20, 16, 1398, 68940, 239, }, },
|
||||
|
||||
};
|
||||
|
||||
return _k_tests;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
if (argc < 2) {
|
||||
fprintf(stderr, "Usage: %s vocab-file [text-file]\n", argv[0]);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const std::string fname = argv[1];
|
||||
|
||||
std::string fname_text;
|
||||
if (argc > 2) {
|
||||
fname_text = argv[2];
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
|
||||
|
||||
llama_model * model;
|
||||
llama_context * ctx;
|
||||
|
||||
llama_backend_init(false);
|
||||
|
||||
// load the vocab
|
||||
{
|
||||
auto mparams = llama_model_default_params();
|
||||
|
||||
mparams.vocab_only = true;
|
||||
|
||||
model = llama_load_model_from_file(fname.c_str(), mparams);
|
||||
|
||||
if (model == NULL) {
|
||||
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto cparams = llama_context_default_params();
|
||||
|
||||
ctx = llama_new_context_with_model(model, cparams);
|
||||
|
||||
if (ctx == NULL) {
|
||||
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
|
||||
llama_free_model(model);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_DEEPSEEKLLM) {
|
||||
fprintf(stderr, "%s : error: vocab type is not DEEPSEEKLLM\n", __func__);
|
||||
llama_free_model(model);
|
||||
llama_free(ctx);
|
||||
return 2;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
// We need this for unicode console support
|
||||
console::init(false, false);
|
||||
atexit([]() { console::cleanup(); });
|
||||
#endif
|
||||
|
||||
bool success = true;
|
||||
|
||||
for (const auto & test_kv : k_tests()) {
|
||||
const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, false);
|
||||
|
||||
printf("\n");
|
||||
printf("src: '%s'\n", test_kv.first.c_str());
|
||||
printf("res: '%s'\n", llama_detokenize_bpe(ctx, res).c_str());
|
||||
printf("tok: ");
|
||||
for (const auto & tok : res) {
|
||||
printf("%d ", tok);
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
bool correct = res.size() == test_kv.second.size();
|
||||
for (int i = 0; i < (int) res.size() && correct; ++i) {
|
||||
if (test_kv.second[i] != res[i]) {
|
||||
correct = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!correct) {
|
||||
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
|
||||
fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
|
||||
llama_detokenize_bpe(ctx, res).c_str(),
|
||||
llama_detokenize_bpe(ctx, test_kv.second).c_str());
|
||||
fprintf(stderr, "%s : expected tokens: ", __func__);
|
||||
for (const auto & t : test_kv.second) {
|
||||
fprintf(stderr, "%6d, ", t);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s : got tokens: ", __func__);
|
||||
for (const auto & t : res) {
|
||||
fprintf(stderr, "%6d, ", t);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
success = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!fname_text.empty()) {
|
||||
fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());
|
||||
|
||||
std::string text;
|
||||
{
|
||||
std::ifstream ifs(fname_text);
|
||||
if (!ifs) {
|
||||
fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_text.c_str());
|
||||
return 1;
|
||||
}
|
||||
text = std::string(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : text size: %zu\n", __func__, text.size());
|
||||
|
||||
const std::vector<llama_token> res = llama_tokenize(ctx, text, false);
|
||||
|
||||
fprintf(stderr, "%s : tokens: %zu\n", __func__, res.size());
|
||||
|
||||
{
|
||||
const std::string fname_out = fname_text + ".tokcpp";
|
||||
|
||||
std::ofstream ofs(fname_out);
|
||||
if (!ofs) {
|
||||
fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_out.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (const auto & tok : res) {
|
||||
ofs << tok << " '" << llama_detokenize_bpe(ctx, std::vector<int>{tok}) << "'" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str());
|
||||
}
|
||||
|
||||
llama_free_model(model);
|
||||
llama_free(ctx);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
return success ? 0 : 3;
|
||||
}
|
83
tests/test-tokenizer-0-deepseek-llm.py
Normal file
83
tests/test-tokenizer-0-deepseek-llm.py
Normal file
|
@ -0,0 +1,83 @@
|
|||
# tests with BPE tokenizer
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
|
||||
parser.add_argument("--fname-tok", help="path to a text file to tokenize")
|
||||
args = parser.parse_args()
|
||||
|
||||
dir_tokenizer = args.dir_tokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_tokenizer)
|
||||
|
||||
tests = [
|
||||
"",
|
||||
" ",
|
||||
" ",
|
||||
" ",
|
||||
"\t",
|
||||
"\n",
|
||||
"\t\n",
|
||||
"Hello world",
|
||||
" Hello world",
|
||||
"Hello World",
|
||||
" Hello World",
|
||||
" Hello World!",
|
||||
"Hello, world!",
|
||||
" Hello, world!",
|
||||
" this is 🦙.cpp",
|
||||
"w048 7tuijk dsdfhu",
|
||||
"нещо на Български",
|
||||
"កាន់តែពិសេសអាចខលចេញ",
|
||||
"🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)",
|
||||
"Hello",
|
||||
" Hello",
|
||||
" Hello",
|
||||
" Hello",
|
||||
" Hello",
|
||||
" Hello\n Hello",
|
||||
"\n =",
|
||||
"' era",
|
||||
"Hello, y'all! How are you 😁 ?我想在apple工作1314151天~",
|
||||
]
|
||||
|
||||
for text in tests:
|
||||
print('text: ', text)
|
||||
print(tokenizer.encode(text))
|
||||
print(tokenizer.decode(tokenizer.encode(text)))
|
||||
|
||||
print("\n\ntests for C++:\n")
|
||||
for text in tests:
|
||||
res = tokenizer.encode(text)
|
||||
|
||||
k = text.replace('\n', '\\n')
|
||||
k = k.replace('\t', '\\t')
|
||||
k = '"' + k + '"'
|
||||
print("{ %-24s, { " % k, end='')
|
||||
for x in res:
|
||||
print("%7d," % x, end='')
|
||||
print(" }, },")
|
||||
|
||||
print(tokenizer.encode('hello'))
|
||||
print(tokenizer.encode('world'))
|
||||
print(tokenizer.encode(' world'))
|
||||
print(tokenizer.encode('hello world'))
|
||||
|
||||
fname_tok = args.fname_tok
|
||||
if fname_tok:
|
||||
print('tokenizing file: ', fname_tok)
|
||||
fname_out = fname_tok + '.tok'
|
||||
with open(fname_tok, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
s = ''.join(lines)
|
||||
res = tokenizer.encode(s)
|
||||
# write to file
|
||||
with open(fname_out, 'w', encoding='utf-8') as f:
|
||||
for x in res:
|
||||
f.write(str(x) + ' \'' + tokenizer.decode(x) + '\'\n')
|
||||
print('len(res): ', len(res))
|
||||
print('len(lines): ', len(lines))
|
||||
print('results written to: ', fname_out)
|
|
@ -38,6 +38,7 @@ static const std::map<std::string, std::vector<llama_token>> & k_tests() {
|
|||
{ " Hello\n Hello" , { 466, 23090, 742, 23090, }, },
|
||||
{ "\n =" , { 1212, 40, }, },
|
||||
{ "' era" , { 18, 4932, }, },
|
||||
{ "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~", { 9856, 23, 291, 18, 436, 12, 1265, 362, 299, 8196, 207, 204, 42, 50087, 123, 2727, 20300, 32022, 133, 234, 17419, 30137, 28, 7858, 181, 133, 236, }, },
|
||||
};
|
||||
|
||||
return _k_tests;
|
||||
|
@ -115,7 +116,6 @@ int main(int argc, char **argv) {
|
|||
printf("\n");
|
||||
|
||||
bool correct = res.size() == test_kv.second.size();
|
||||
|
||||
for (int i = 0; i < (int) res.size() && correct; ++i) {
|
||||
if (test_kv.second[i] != res[i]) {
|
||||
correct = false;
|
||||
|
|
|
@ -41,6 +41,7 @@ tests = [
|
|||
" Hello\n Hello",
|
||||
"\n =",
|
||||
"' era",
|
||||
"Hello, y'all! How are you 😁 ?我想在apple工作1314151天~",
|
||||
]
|
||||
|
||||
for text in tests:
|
||||
|
|
47
unicode.h
47
unicode.h
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue