Compare commits


10 commits

Author SHA1 Message Date

Georgi Gerganov
703764a382 convert : use non-fast T5 tokenizer 2024-07-02 19:29:26 +03:00

Georgi Gerganov
17bb0eaec3 llama : UGM tokenizer init with UNK tokens instead of PAD 2024-07-02 10:40:14 +03:00

Georgi Gerganov
9eb5d5617d convert : add t5 tokenizer tests 2024-07-02 10:39:49 +03:00

Stanisław Szymczyk
6dc9eb4040 llama : quantization-related fixes for T5 2024-06-29 18:09:22 +02:00

Stanisław Szymczyk
7d7fff4654 llama : whitespace formatting 2024-06-27 10:13:53 +02:00

Stanisław Szymczyk
7293243d4f Merge remote-tracking branch 'upstream/master' into t5-clean-3 2024-06-27 09:29:26 +02:00

Stanisław Szymczyk
c4ded1a8fb llama : make pos_bias contiguous for CUDA 2024-06-26 17:46:39 +02:00

Stanisław Szymczyk
bad0cafee9 llama : updated llm_build_ffn() calls to new API in build_t5() 2024-06-26 17:38:13 +02:00

fairydreaming
1c8d37a267 Merge branch 'ggerganov:master' into t5-clean-3 2024-06-26 17:31:15 +02:00

Stanisław Szymczyk
45681a57dd llama : add inference support and model types for T5 and FLAN-T5 model families 2024-06-26 15:03:01 +02:00

llama : add new API functions to support encoder-decoder models: llama_encode(), llama_model_has_encoder(), llama_model_decoder_start_token()

common, llama-cli : use new API functions to support encoder-decoder models

convert-hf : handle shared token embeddings tensors in T5Model

convert-hf : handle SentencePiece BPE tokenizer in T5Model (for Pile-T5 models)

convert-hf : add MT5ForConditionalGeneration and UMT5ForConditionalGeneration to architectures supported by T5Model
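
To see how the pieces listed in this commit message fit together, here is a minimal sketch of the encoder-decoder flow (not the PR's exact code; it mirrors the llama-cli changes in the diffs below). Tokenization, sampling, and error handling are elided; `model` and `ctx` are assumed to be initialized, and `tokenize_prompt()` is a hypothetical helper:

// minimal sketch of the encoder-decoder flow, using only the new API
// functions named above
std::vector<llama_token> tokens = tokenize_prompt(); // hypothetical helper

if (llama_model_has_encoder(model)) {
    // run the encoder once over the whole prompt; its output is stored
    // internally for the decoder cross-attention layers
    llama_encode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0));

    // seed the decoder with a model-specific start token, falling back
    // to BOS when the model does not define one
    llama_token decoder_start = llama_model_decoder_start_token(model);
    if (decoder_start == -1) {
        decoder_start = llama_token_bos(model);
    }
    tokens.clear();
    tokens.push_back(decoder_start);
}

// from here on, the usual llama_decode() loop generates the output sequence
llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0));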
32 changed files with 955 additions and 27 deletions

View file

@@ -2061,7 +2061,24 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params
     if (params.warmup) {
         LOG("warming up the model with an empty run\n");

-        std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
+        std::vector<llama_token> tmp;
+        llama_token bos = llama_token_bos(model);
+        llama_token eos = llama_token_eos(model);
+        // some models (e.g. T5) don't have a BOS token
+        if (bos != -1) {
+            tmp.push_back(bos);
+        }
+        tmp.push_back(eos);
+
+        if (llama_model_has_encoder(model)) {
+            llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
+            llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+            if (decoder_start_token_id == -1) {
+                decoder_start_token_id = bos;
+            }
+            tmp.clear();
+            tmp.push_back(decoder_start_token_id);
+        }
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);

View file

@@ -45,6 +45,7 @@ class TOKENIZER_TYPE(IntEnum):
     SPM = auto()
     BPE = auto()
     WPM = auto()
+    UGM = auto()


 # TODO: this string has to exercise as much pre-tokenizer functionality as possible
@@ -85,6 +86,7 @@ models = [
     {"name": "smaug-bpe",    "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", },
     {"name": "poro-chat",    "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
     {"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },
+    {"name": "t5",           "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", },
 ]
@@ -106,9 +108,13 @@ def download_model(model):
     os.makedirs(f"models/tokenizers/{name}", exist_ok=True)

     files = ["config.json", "tokenizer.json", "tokenizer_config.json"]
+
     if tokt == TOKENIZER_TYPE.SPM:
         files.append("tokenizer.model")

+    if tokt == TOKENIZER_TYPE.UGM:
+        files.append("spiece.model")
+
     for file in files:
         save_path = f"models/tokenizers/{name}/{file}"
         if os.path.isfile(save_path):
@@ -131,7 +137,7 @@ for model in models:
     name = model["name"]
     tokt = model["tokt"]

-    if tokt == TOKENIZER_TYPE.SPM:
+    if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
         continue

     # Skip if the tokenizer folder does not exist or there are other download issues previously
@@ -141,7 +147,10 @@ for model in models:
     # create the tokenizer
     try:
-        tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
+        if name == "t5":
+            tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
     except OSError as e:
         logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}")
         continue  # Skip to the next model if the tokenizer can't be loaded
@@ -262,6 +271,7 @@ tests = [
     "\n =",
     "' era",
     "Hello, y'all! How are you 😁 ?我想在apple工作1314151天",
+    "!!!!!!",
     "3",
     "33",
     "333",
@@ -299,7 +309,10 @@ for model in models:
     # create the tokenizer
     try:
-        tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
+        if name == "t5":
+            tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
     except OSError as e:
         logger.error(f"Failed to load tokenizer for model {name}. Error: {e}")
         continue  # Skip this model and continue with the next one in the loop

View file

@@ -2775,11 +2775,17 @@ class DeepseekV2Model(Model):
             raise ValueError(f"Unprocessed experts: {experts}")


-@Model.register("T5ForConditionalGeneration")
 @Model.register("T5WithLMHeadModel")
+@Model.register("T5ForConditionalGeneration")
+@Model.register("MT5ForConditionalGeneration")
+@Model.register("UMT5ForConditionalGeneration")
 class T5Model(Model):
     model_arch = gguf.MODEL_ARCH.T5

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.shared_token_embeddings_found = False
+
     def set_vocab(self):
         # to avoid TypeError: Descriptors cannot be created directly
         # exception when importing sentencepiece_model_pb2
@@ -2787,17 +2793,29 @@ class T5Model(Model):
         from sentencepiece import SentencePieceProcessor
         from sentencepiece import sentencepiece_model_pb2 as model

-        tokenizer_path = self.dir_model / 'spiece.model'
+        tokenizer_path = self.dir_model / 'tokenizer.model'
+
+        # many older models use spiece.model tokenizer model filename
+        if not tokenizer_path.is_file():
+            tokenizer_path = self.dir_model / 'spiece.model'

         if not tokenizer_path.is_file():
             raise FileNotFoundError(f"File not found: {tokenizer_path}")

         sentencepiece_model = model.ModelProto()
         sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())

+        # some models like Pile-T5 family use BPE tokenizer instead of Unigram
+        if sentencepiece_model.trainer_spec.model_type == 2:  # BPE
+            # assure the tokenizer model file name is correct
+            assert tokenizer_path.name == 'tokenizer.model'
+            return self._set_vocab_sentencepiece()
+        else:
+            assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM
+
         add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
         remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
         precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
-        assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM

         tokenizer = SentencePieceProcessor()
         tokenizer.LoadFromFile(str(tokenizer_path))
@@ -2867,7 +2885,10 @@ class T5Model(Model):
     def set_gguf_parameters(self):
         self.gguf_writer.add_name("T5")
-        self.gguf_writer.add_context_length(self.hparams["n_positions"])
+        if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
+            logger.warning("Couldn't find context length in config.json, assuming default value of 512")
+            n_ctx = 512
+        self.gguf_writer.add_context_length(n_ctx)
         self.gguf_writer.add_embedding_length(self.hparams["d_model"])
         self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
         self.gguf_writer.add_block_count(self.hparams["num_layers"])
@@ -2883,12 +2904,17 @@ class T5Model(Model):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused

-        # Sometimes T5 and Flan-T5 based models contain "encoder.embed_tokens.weight" tensor or
-        # "decoder.embed_tokens.weight" tensors that are duplicates of "shared.weight" tensor
-        # To prevent errors caused by an unnecessary unmapped tensor, skip both of them and use only "shared.weight".
-        if name == "decoder.embed_tokens.weight" or name == "encoder.embed_tokens.weight":
-            logger.debug(f"Skipping tensor {name!r} in safetensors so that convert can end normally.")
-            return []
+        # T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight",
+        # "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored
+        # in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder
+        # and decoder and ignore the remaining ones.
+        if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]:
+            if not self.shared_token_embeddings_found:
+                name = "shared.weight"
+                self.shared_token_embeddings_found = True
+            else:
+                logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.")
+                return []

         return [(self.map_tensor_name(name), data_torch)]

View file

@@ -255,7 +255,9 @@ int main(int argc, char ** argv) {
     }

     const bool add_bos = llama_should_add_bos_token(model);
-    GGML_ASSERT(llama_add_eos_token(model) != 1);
+    if (!llama_model_has_encoder(model)) {
+        GGML_ASSERT(llama_add_eos_token(model) != 1);
+    }
     LOG("add_bos: %d\n", add_bos);

     std::vector<llama_token> embd_inp;
@@ -517,6 +519,23 @@ int main(int argc, char ** argv) {
             exit(1);
         }

+    if (llama_model_has_encoder(model)) {
+        int enc_input_size = embd_inp.size();
+        llama_token * enc_input_buf = embd_inp.data();
+
+        if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
+            LOG_TEE("%s : failed to eval\n", __func__);
+            return 1;
+        }
+
+        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+        if (decoder_start_token_id == -1) {
+            decoder_start_token_id = llama_token_bos(model);
+        }
+
+        embd_inp.clear();
+        embd_inp.push_back(decoder_start_token_id);
+    }
+
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (!embd.empty()) {

View file

@@ -483,6 +483,13 @@ extern "C" {
     // Get a llama model tensor
     LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);

+    // Returns true if the model contains an encoder that requires llama_encode() call
+    LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
+
+    // For encoder-decoder models, this function returns id of the token that must be provided
+    // to the decoder to start generating output sequence. For other models, it returns -1.
+    LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
@@ -768,6 +775,14 @@ extern "C" {
     // Frees a batch of tokens allocated with llama_batch_init()
     LLAMA_API void llama_batch_free(struct llama_batch batch);

+    // Processes a batch of tokens with the encoder part of the encoder-decoder model.
+    // Stores the encoder output internally for later use by the decoder cross-attention layers.
+    //   0 - success
+    // < 0 - error
+    LLAMA_API int32_t llama_encode(
+            struct llama_context * ctx,
+            struct llama_batch     batch);
+
     // Positive return values do not mean a fatal error, but rather a warning.
     //   0 - success
     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
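
As a usage note (not part of the diff), a caller-side sketch of the contracts documented in these declarations; `inp` is a hypothetical, already-tokenized prompt:

// llama_encode() returns 0 on success and < 0 on error;
// llama_model_decoder_start_token() returns -1 for models without one
if (llama_model_has_encoder(model)) {
    const int32_t res = llama_encode(ctx, llama_batch_get_one(inp.data(), inp.size(), 0, 0));
    if (res < 0) {
        fprintf(stderr, "llama_encode() failed with %d\n", res);
        return 1;
    }
}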

View file

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

View file

@@ -31,6 +31,7 @@
 1027
 1005 3690
 7592 1010 1061 1005 2035 999 2129 2024 2017 100 1029 1855 100 100 6207 100 100 14677 23632 22203 1811 1995
+999 999 999 999 999 999
 1017
 3943
 21211

View file

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

View file

@@ -31,6 +31,7 @@
 206 1857
 14 4515
 28339 19 1770 14 1954 8 4070 1955 1933 80503 231 5691 12081 13336 2648 29325 14315 24 26 24 27 24 28 24 5123 18372
+57178 10251
 26
 26 26
 26 26 26

View file

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

View file

@@ -31,6 +31,7 @@
 185 405
 6 2895
 17535 11 320 6 435 0 1717 417 340 12394 233 210 3015 19100 608 9413 2668 16 18 16 19 16 20 16 1393 169 121 239
+15330 3023
 18
 18 18
 18 18 18

View file

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

View file

@@ -31,6 +31,7 @@
 185 403
 6 2906
 17464 11 320 6 436 0 1724 418 340 33701 210 3025 19017 612 9407 2681 16 18 16 19 16 20 16 1398 68940 239
+15278 3033
 18
 18 18
 18 18 18

View file

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

View file

@@ -31,6 +31,7 @@
 1212 40
 18 4932
 9856 23 291 18 436 12 1265 362 299 8196 207 204 42 50087 123 2727 20300 32022 133 234 17419 30137 28 7858 181 133 236
+51520
 30
 3138
 22287

View file

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

View file

@@ -31,6 +31,7 @@
 198 796
 6 6980
 15496 11 331 6 439 0 1374 389 345 30325 223 5633 22755 239 46349 111 28839 101 18040 32432 98 43291 1485 1415 24309 25465 171 121 252
+13896 3228
 18
 2091
 20370

View file

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33
@@ -104,5 +106,3 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天 ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
 __ggml_vocab_test__
-Việt
-__ggml_vocab_test__

View file

@@ -31,6 +31,7 @@
 198 284
 6 11639
 9906 11 379 65948 0 2650 527 499 27623 223 949 37046 101067 19000 23182 102301 9263 18136 16 36827 21909
+17523 3001
 18
 1644
 8765
@@ -41,4 +42,3 @@
 8765 8765 1644
 8765 8765 8765
 198 4815 15073 66597 8004 1602 2355 79772 11187 9468 248 222 320 8416 8 27623 114 102470 9468 234 104 31643 320 36773 100166 98634 8 26602 227 11410 99 247 9468 99 247 220 18 220 1644 220 8765 220 8765 18 220 8765 1644 220 8765 8765 220 8765 8765 18 220 8765 8765 1644 220 18 13 18 220 18 497 18 220 18 1131 18 220 21549 222 98629 241 45358 233 21549 237 45358 224 21549 244 21549 115 21549 253 45358 223 21549 253 21549 95 98629 227 76460 223 949 37046 101067 19000 23182 102301 9263 18136 16 36827 21909 56560 54337 19175 102118 13373 64571 34694 3114 112203 80112 3436 106451 14196 14196 74694 3089 3089 29249 17523 3001 27708 7801 358 3077 1027 364 83 820 568 596 1070 11 364 793 499 2771 30 364 44 539 2771 358 3358 1304 433 11 364 35 499 1093 1063 15600 30 1226 6 43712 264 64966 43
-101798

View file

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

View file

@@ -31,6 +31,7 @@
 29871 13 353
 525 3152
 15043 29892 343 29915 497 29991 1128 526 366 29871 243 162 155 132 1577 30672 31522 30505 11548 31041 30732 29896 29941 29896 29946 29896 29945 29896 30408 30739
+1738 6824 21004
 29871 29941
 29871 29941 29941
 29871 29941 29941 29941

View file

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

View file

@@ -31,6 +31,7 @@
 187 426
 8 8685
 12092 13 340 8 455 2 1359 403 368 49042 212 3736 15367 41197 13610 19934 41869 21275 1012 1047 18795 40120 20422 241
+18963 4672
 20
 1610
 20084

View file

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

View file

@@ -31,6 +31,7 @@
 29871 13 353
 525 3152
 15043 29892 343 29915 497 29991 1128 526 366 29871 243 162 155 132 1577 30672 31522 30505 11548 31041 30732 29896 29941 29896 29946 29896 29945 29896 30408 30739
+1738 6824 21004
 29871 29941
 29871 29941 29941
 29871 29941 29941 29941

View file

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

View file

@@ -31,6 +31,7 @@
 198 284
 6 11385
 9707 11 379 64848 0 2585 525 498 26525 223 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216
+17085 2928
 18
 18 18
 18 18 18

View file

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

View file

@@ -31,6 +31,7 @@
 203 280
 25 34666
 8279 30 533 25 464 19 4971 884 844 18458 228 1018 4982 13368 2909 9513 17827 35 37 35 38 35 39 35 11873 47838
+9163 3202
 37
 37 37
 37 37 37

View file

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

View file

@@ -31,6 +31,7 @@
 222 299
 44 34719
 8302 49 553 44 483 38 4998 904 863 18445 247 1037 4995 13379 2924 9515 17823 54 56 54 57 54 58 54 11904 47892
+9221 3226
 56
 56 56
 56 56 56

File diff suppressed because it is too large.