Compare commits
10 commits
master
...
fairydream
Author | SHA1 | Date | |
---|---|---|---|
|
703764a382 | ||
|
17bb0eaec3 | ||
|
9eb5d5617d | ||
|
6dc9eb4040 | ||
|
7d7fff4654 | ||
|
7293243d4f | ||
|
c4ded1a8fb | ||
|
bad0cafee9 | ||
|
1c8d37a267 | ||
|
45681a57dd |
32 changed files with 955 additions and 27 deletions
|
@ -2061,7 +2061,24 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
|||
if (params.warmup) {
|
||||
LOG("warming up the model with an empty run\n");
|
||||
|
||||
std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
|
||||
std::vector<llama_token> tmp;
|
||||
llama_token bos = llama_token_bos(model);
|
||||
llama_token eos = llama_token_eos(model);
|
||||
// some models (e.g. T5) don't have a BOS token
|
||||
if (bos != -1) {
|
||||
tmp.push_back(bos);
|
||||
}
|
||||
tmp.push_back(eos);
|
||||
|
||||
if (llama_model_has_encoder(model)) {
|
||||
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
|
||||
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
|
||||
if (decoder_start_token_id == -1) {
|
||||
decoder_start_token_id = bos;
|
||||
}
|
||||
tmp.clear();
|
||||
tmp.push_back(decoder_start_token_id);
|
||||
}
|
||||
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
|
||||
llama_kv_cache_clear(lctx);
|
||||
llama_synchronize(lctx);
|
||||
|
|
|
@ -45,6 +45,7 @@ class TOKENIZER_TYPE(IntEnum):
|
|||
SPM = auto()
|
||||
BPE = auto()
|
||||
WPM = auto()
|
||||
UGM = auto()
|
||||
|
||||
|
||||
# TODO: this string has to exercise as much pre-tokenizer functionality as possible
|
||||
|
@ -85,6 +86,7 @@ models = [
|
|||
{"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", },
|
||||
{"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
|
||||
{"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },
|
||||
{"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", },
|
||||
]
|
||||
|
||||
|
||||
|
@ -106,9 +108,13 @@ def download_model(model):
|
|||
os.makedirs(f"models/tokenizers/{name}", exist_ok=True)
|
||||
|
||||
files = ["config.json", "tokenizer.json", "tokenizer_config.json"]
|
||||
|
||||
if tokt == TOKENIZER_TYPE.SPM:
|
||||
files.append("tokenizer.model")
|
||||
|
||||
if tokt == TOKENIZER_TYPE.UGM:
|
||||
files.append("spiece.model")
|
||||
|
||||
for file in files:
|
||||
save_path = f"models/tokenizers/{name}/{file}"
|
||||
if os.path.isfile(save_path):
|
||||
|
@ -131,7 +137,7 @@ for model in models:
|
|||
name = model["name"]
|
||||
tokt = model["tokt"]
|
||||
|
||||
if tokt == TOKENIZER_TYPE.SPM:
|
||||
if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
|
||||
continue
|
||||
|
||||
# Skip if the tokenizer folder does not exist or there are other download issues previously
|
||||
|
@ -141,7 +147,10 @@ for model in models:
|
|||
|
||||
# create the tokenizer
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
|
||||
if name == "t5":
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
|
||||
except OSError as e:
|
||||
logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}")
|
||||
continue # Skip to the next model if the tokenizer can't be loaded
|
||||
|
@ -262,6 +271,7 @@ tests = [
|
|||
"\n =",
|
||||
"' era",
|
||||
"Hello, y'all! How are you 😁 ?我想在apple工作1314151天~",
|
||||
"!!!!!!",
|
||||
"3",
|
||||
"33",
|
||||
"333",
|
||||
|
@ -299,7 +309,10 @@ for model in models:
|
|||
|
||||
# create the tokenizer
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
|
||||
if name == "t5":
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
|
||||
except OSError as e:
|
||||
logger.error(f"Failed to load tokenizer for model {name}. Error: {e}")
|
||||
continue # Skip this model and continue with the next one in the loop
|
||||
|
|
|
@ -2775,11 +2775,17 @@ class DeepseekV2Model(Model):
|
|||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("T5ForConditionalGeneration")
|
||||
@Model.register("T5WithLMHeadModel")
|
||||
@Model.register("T5ForConditionalGeneration")
|
||||
@Model.register("MT5ForConditionalGeneration")
|
||||
@Model.register("UMT5ForConditionalGeneration")
|
||||
class T5Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.T5
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.shared_token_embeddings_found = False
|
||||
|
||||
def set_vocab(self):
|
||||
# to avoid TypeError: Descriptors cannot be created directly
|
||||
# exception when importing sentencepiece_model_pb2
|
||||
|
@ -2787,17 +2793,29 @@ class T5Model(Model):
|
|||
from sentencepiece import SentencePieceProcessor
|
||||
from sentencepiece import sentencepiece_model_pb2 as model
|
||||
|
||||
tokenizer_path = self.dir_model / 'spiece.model'
|
||||
tokenizer_path = self.dir_model / 'tokenizer.model'
|
||||
|
||||
# many older models use spiece.model tokenizer model filename
|
||||
if not tokenizer_path.is_file():
|
||||
tokenizer_path = self.dir_model / 'spiece.model'
|
||||
|
||||
if not tokenizer_path.is_file():
|
||||
raise FileNotFoundError(f"File not found: {tokenizer_path}")
|
||||
|
||||
sentencepiece_model = model.ModelProto()
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
|
||||
# some models like Pile-T5 family use BPE tokenizer instead of Unigram
|
||||
if sentencepiece_model.trainer_spec.model_type == 2: # BPE
|
||||
# assure the tokenizer model file name is correct
|
||||
assert tokenizer_path.name == 'tokenizer.model'
|
||||
return self._set_vocab_sentencepiece()
|
||||
else:
|
||||
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
|
||||
|
||||
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
|
||||
remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
|
||||
precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
|
||||
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
|
||||
|
||||
tokenizer = SentencePieceProcessor()
|
||||
tokenizer.LoadFromFile(str(tokenizer_path))
|
||||
|
@ -2867,7 +2885,10 @@ class T5Model(Model):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_name("T5")
|
||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||
if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
|
||||
logger.warning("Couldn't find context length in config.json, assuming default value of 512")
|
||||
n_ctx = 512
|
||||
self.gguf_writer.add_context_length(n_ctx)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_layers"])
|
||||
|
@ -2883,12 +2904,17 @@ class T5Model(Model):
|
|||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
# Sometimes T5 and Flan-T5 based models contain "encoder.embed_tokens.weight" tensor or
|
||||
# "decoder.embed_tokens.weight" tensors that are duplicates of "shared.weight" tensor
|
||||
# To prevent errors caused by an unnecessary unmapped tensor, skip both of them and use only "shared.weight".
|
||||
if name == "decoder.embed_tokens.weight" or name == "encoder.embed_tokens.weight":
|
||||
logger.debug(f"Skipping tensor {name!r} in safetensors so that convert can end normally.")
|
||||
return []
|
||||
# T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight",
|
||||
# "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored
|
||||
# in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder
|
||||
# and decoder and ignore the remaining ones.
|
||||
if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]:
|
||||
if not self.shared_token_embeddings_found:
|
||||
name = "shared.weight"
|
||||
self.shared_token_embeddings_found = True
|
||||
else:
|
||||
logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.")
|
||||
return []
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
|
|
@ -255,7 +255,9 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
const bool add_bos = llama_should_add_bos_token(model);
|
||||
GGML_ASSERT(llama_add_eos_token(model) != 1);
|
||||
if (!llama_model_has_encoder(model)) {
|
||||
GGML_ASSERT(llama_add_eos_token(model) != 1);
|
||||
}
|
||||
LOG("add_bos: %d\n", add_bos);
|
||||
|
||||
std::vector<llama_token> embd_inp;
|
||||
|
@ -517,6 +519,23 @@ int main(int argc, char ** argv) {
|
|||
exit(1);
|
||||
}
|
||||
|
||||
if (llama_model_has_encoder(model)) {
|
||||
int enc_input_size = embd_inp.size();
|
||||
llama_token * enc_input_buf = embd_inp.data();
|
||||
|
||||
if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
|
||||
LOG_TEE("%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
|
||||
if (decoder_start_token_id == -1) {
|
||||
decoder_start_token_id = llama_token_bos(model);
|
||||
}
|
||||
embd_inp.clear();
|
||||
embd_inp.push_back(decoder_start_token_id);
|
||||
}
|
||||
|
||||
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
||||
// predict
|
||||
if (!embd.empty()) {
|
||||
|
|
|
@ -483,6 +483,13 @@ extern "C" {
|
|||
// Get a llama model tensor
|
||||
LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
|
||||
|
||||
// Returns true if the model contains an encoder that requires llama_encode() call
|
||||
LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
|
||||
|
||||
// For encoder-decoder models, this function returns id of the token that must be provided
|
||||
// to the decoder to start generating output sequence. For other models, it returns -1.
|
||||
LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
|
||||
|
||||
// Returns 0 on success
|
||||
LLAMA_API uint32_t llama_model_quantize(
|
||||
const char * fname_inp,
|
||||
|
@ -768,6 +775,14 @@ extern "C" {
|
|||
// Frees a batch of tokens allocated with llama_batch_init()
|
||||
LLAMA_API void llama_batch_free(struct llama_batch batch);
|
||||
|
||||
// Processes a batch of tokens with the ecoder part of the encoder-decoder model.
|
||||
// Stores the encoder output internally for later use by the decoder cross-attention layers.
|
||||
// 0 - success
|
||||
// < 0 - error
|
||||
LLAMA_API int32_t llama_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
|
||||
|
||||
// Positive return values does not mean a fatal error, but rather a warning.
|
||||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
|
|
|
@ -73,6 +73,8 @@ __ggml_vocab_test__
|
|||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
1027
|
||||
1005 3690
|
||||
7592 1010 1061 1005 2035 999 2129 2024 2017 100 1029 1855 100 100 6207 100 100 14677 23632 22203 1811 1995
|
||||
999 999 999 999 999 999
|
||||
1017
|
||||
3943
|
||||
21211
|
||||
|
|
|
@ -73,6 +73,8 @@ __ggml_vocab_test__
|
|||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
206 1857
|
||||
14 4515
|
||||
28339 19 1770 14 1954 8 4070 1955 1933 80503 231 5691 12081 13336 2648 29325 14315 24 26 24 27 24 28 24 5123 18372
|
||||
57178 10251
|
||||
26
|
||||
26 26
|
||||
26 26 26
|
||||
|
|
|
@ -73,6 +73,8 @@ __ggml_vocab_test__
|
|||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
185 405
|
||||
6 2895
|
||||
17535 11 320 6 435 0 1717 417 340 12394 233 210 3015 19100 608 9413 2668 16 18 16 19 16 20 16 1393 169 121 239
|
||||
15330 3023
|
||||
18
|
||||
18 18
|
||||
18 18 18
|
||||
|
|
|
@ -73,6 +73,8 @@ __ggml_vocab_test__
|
|||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
185 403
|
||||
6 2906
|
||||
17464 11 320 6 436 0 1724 418 340 33701 210 3025 19017 612 9407 2681 16 18 16 19 16 20 16 1398 68940 239
|
||||
15278 3033
|
||||
18
|
||||
18 18
|
||||
18 18 18
|
||||
|
|
|
@ -73,6 +73,8 @@ __ggml_vocab_test__
|
|||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
1212 40
|
||||
18 4932
|
||||
9856 23 291 18 436 12 1265 362 299 8196 207 204 42 50087 123 2727 20300 32022 133 234 17419 30137 28 7858 181 133 236
|
||||
51520
|
||||
30
|
||||
3138
|
||||
22287
|
||||
|
|
|
@ -73,6 +73,8 @@ __ggml_vocab_test__
|
|||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
198 796
|
||||
6 6980
|
||||
15496 11 331 6 439 0 1374 389 345 30325 223 5633 22755 239 46349 111 28839 101 18040 32432 98 43291 1485 1415 24309 25465 171 121 252
|
||||
13896 3228
|
||||
18
|
||||
2091
|
||||
20370
|
||||
|
|
|
@ -73,6 +73,8 @@ __ggml_vocab_test__
|
|||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
|
@ -104,5 +106,3 @@ __ggml_vocab_test__
|
|||
|
||||
🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
|
||||
__ggml_vocab_test__
|
||||
Việt
|
||||
__ggml_vocab_test__
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
198 284
|
||||
6 11639
|
||||
9906 11 379 65948 0 2650 527 499 27623 223 949 37046 101067 19000 23182 102301 9263 18136 16 36827 21909
|
||||
17523 3001
|
||||
18
|
||||
1644
|
||||
8765
|
||||
|
@ -41,4 +42,3 @@
|
|||
8765 8765 1644
|
||||
8765 8765 8765
|
||||
198 4815 15073 66597 8004 1602 2355 79772 11187 9468 248 222 320 8416 8 27623 114 102470 9468 234 104 31643 320 36773 100166 98634 8 26602 227 11410 99 247 9468 99 247 220 18 220 1644 220 8765 220 8765 18 220 8765 1644 220 8765 8765 220 8765 8765 18 220 8765 8765 1644 220 18 13 18 220 18 497 18 220 18 1131 18 220 21549 222 98629 241 45358 233 21549 237 45358 224 21549 244 21549 115 21549 253 45358 223 21549 253 21549 95 98629 227 76460 223 949 37046 101067 19000 23182 102301 9263 18136 16 36827 21909 56560 54337 19175 102118 13373 64571 34694 3114 112203 80112 3436 106451 14196 14196 74694 3089 3089 29249 17523 3001 27708 7801 358 3077 1027 364 83 820 568 596 1070 11 364 793 499 2771 30 364 44 539 2771 358 3358 1304 433 11 364 35 499 1093 1063 15600 30 1226 6 43712 264 64966 43
|
||||
101798
|
||||
|
|
|
@ -73,6 +73,8 @@ __ggml_vocab_test__
|
|||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
29871 13 353
|
||||
525 3152
|
||||
15043 29892 343 29915 497 29991 1128 526 366 29871 243 162 155 132 1577 30672 31522 30505 11548 31041 30732 29896 29941 29896 29946 29896 29945 29896 30408 30739
|
||||
1738 6824 21004
|
||||
29871 29941
|
||||
29871 29941 29941
|
||||
29871 29941 29941 29941
|
||||
|
|
|
@ -73,6 +73,8 @@ __ggml_vocab_test__
|
|||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
187 426
|
||||
8 8685
|
||||
12092 13 340 8 455 2 1359 403 368 49042 212 3736 15367 41197 13610 19934 41869 21275 1012 1047 18795 40120 20422 241
|
||||
18963 4672
|
||||
20
|
||||
1610
|
||||
20084
|
||||
|
|
|
@ -73,6 +73,8 @@ __ggml_vocab_test__
|
|||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
29871 13 353
|
||||
525 3152
|
||||
15043 29892 343 29915 497 29991 1128 526 366 29871 243 162 155 132 1577 30672 31522 30505 11548 31041 30732 29896 29941 29896 29946 29896 29945 29896 30408 30739
|
||||
1738 6824 21004
|
||||
29871 29941
|
||||
29871 29941 29941
|
||||
29871 29941 29941 29941
|
||||
|
|
|
@ -73,6 +73,8 @@ __ggml_vocab_test__
|
|||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
198 284
|
||||
6 11385
|
||||
9707 11 379 64848 0 2585 525 498 26525 223 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216
|
||||
17085 2928
|
||||
18
|
||||
18 18
|
||||
18 18 18
|
||||
|
|
|
@ -73,6 +73,8 @@ __ggml_vocab_test__
|
|||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
203 280
|
||||
25 34666
|
||||
8279 30 533 25 464 19 4971 884 844 18458 228 1018 4982 13368 2909 9513 17827 35 37 35 38 35 39 35 11873 47838
|
||||
9163 3202
|
||||
37
|
||||
37 37
|
||||
37 37 37
|
||||
|
|
|
@ -73,6 +73,8 @@ __ggml_vocab_test__
|
|||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
222 299
|
||||
44 34719
|
||||
8302 49 553 44 483 38 4998 904 863 18445 247 1037 4995 13379 2924 9515 17823 54 56 54 57 54 58 54 11904 47892
|
||||
9221 3226
|
||||
56
|
||||
56 56
|
||||
56 56 56
|
||||
|
|
820
src/llama.cpp
820
src/llama.cpp
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue