Style

parent edf375d26f
commit a5fa2fec60

1 changed file with 29 additions and 28 deletions
@@ -21,7 +21,7 @@ class LibLlama:
     DEFAULT_PATH_LLAMA_H = "./llama.h"
     DEFAULT_PATH_LIBLLAMA = "./build/libllama.so" # CMakeLists.txt: BUILD_SHARED_LIBS ON

-    def __init__(self, path_llama_h:str=None, path_libllama:str=None):
+    def __init__(self, path_llama_h: str = None, path_libllama: str = None):
         path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H
         path_libllama = path_libllama or self.DEFAULT_PATH_LIBLLAMA
         (self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_libllama)
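_load_libllama_cffi itself is outside this diff, but the two constants above show its inputs: the llama.h header and the shared library built with BUILD_SHARED_LIBS ON. For readers unfamiliar with the pattern, a minimal sketch of header-driven loading with cffi follows; the function name and the assumption that the header has already been reduced to plain C declarations are illustrative, not code from this commit:

    # Hypothetical sketch of a cffi-based loader (not from this commit).
    from cffi import FFI

    def load_libllama_cffi(path_llama_h: str, path_libllama: str):
        ffi = FFI()
        with open(path_llama_h) as f:
            # cdef() accepts only plain C declarations, so a real loader
            # must first strip or expand preprocessor directives in llama.h.
            ffi.cdef(f.read())
        lib = ffi.dlopen(path_libllama)  # e.g. ./build/libllama.so
        return ffi, lib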
@@ -55,21 +55,22 @@ class LibLlama:
             setattr(cparams, k, v)
         return cparams


 class LibLlamaModel:

-    def __init__(self, libllama:LibLlama, path_model:str, mparams={}, cparams={}):
+    def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}):
         self.lib = libllama.lib
         self.ffi = libllama.ffi
         if type(mparams) == dict:
             mparams = libllama.model_default_params(**mparams)
         self.model = self.lib.llama_load_model_from_file(path_model.encode(), mparams)
         if not self.model:
-            raise RuntimeError("error: failed to load model '%s'"%path_model)
+            raise RuntimeError("error: failed to load model '%s'" % path_model)
         if type(cparams) == dict:
             cparams = libllama.context_default_params(**cparams)
         self.ctx = self.lib.llama_new_context_with_model(self.model, cparams)
         if not self.ctx:
-            raise RuntimeError("error: failed to create context for model '%s'"%path_model)
+            raise RuntimeError("error: failed to create context for model '%s'" % path_model)
         n_tokens_max = self.lib.llama_n_ctx(self.ctx)
         self.token_ids = self.ffi.new("llama_token[]", n_tokens_max)

@@ -82,7 +83,7 @@ class LibLlamaModel:
         self.model = None
         self.lib = None

-    def tokenize(self, text:str, n_tokens_max:int=0, add_special:bool=False, parse_special:bool=False) -> list[int]:
+    def tokenize(self, text: str, n_tokens_max: int = 0, add_special: bool = False, parse_special: bool = False) -> list[int]:
         n_tokens_max = n_tokens_max if n_tokens_max > 0 else len(self.token_ids)
         text = text.encode("utf-8")
         num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, n_tokens_max, add_special, parse_special)
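A note on the llama_tokenize call shown above: in llama.h it returns the number of tokens written, and a negative value when the buffer of n_tokens_max ids is too small, with the magnitude indicating the required size. The method here sizes token_ids to the context length up front, so the following grow-and-retry wrapper is only an illustration of that convention, not part of the commit:

    # Illustrative grow-and-retry wrapper around llama_tokenize
    # (an assumption for exposition; the commit's method does not retry).
    def tokenize_retry(self, text: str, add_special=False, parse_special=False) -> list[int]:
        data = text.encode("utf-8")
        num = self.lib.llama_tokenize(self.model, data, len(data), self.token_ids,
                                      len(self.token_ids), add_special, parse_special)
        if num < 0:
            # -num is the number of tokens the text actually needs.
            self.token_ids = self.ffi.new("llama_token[]", -num)
            num = self.lib.llama_tokenize(self.model, data, len(data), self.token_ids,
                                          -num, add_special, parse_special)
        return list(self.token_ids[0:num])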
@@ -91,14 +92,14 @@ class LibLlamaModel:
         return list(self.token_ids[0:num])


-def find_first_mismatch(ids1:list[int], ids2:list[int]):
+def find_first_mismatch(ids1: list[int], ids2: list[int]):
     for i, (a,b) in enumerate(zip(ids1, ids2)):
         if a != b:
             return i
     return -1 if len(ids1) == len(ids2) else i


-def test_custom_texts(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase):
+def test_custom_texts(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase):

     tests = [
         "",
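find_first_mismatch returns the index of the first differing id, -1 when the lists are identical, and the last compared index when one list is a strict prefix of the other; note that i is unbound (and the function raises NameError) if one list is empty and the other is not. Illustrative calls:

    assert find_first_mismatch([1, 2, 3], [1, 2, 3]) == -1  # identical
    assert find_first_mismatch([1, 2, 3], [1, 9, 3]) == 1   # first difference at index 1
    assert find_first_mismatch([1, 2], [1, 2, 3]) == 1      # prefix: last compared index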
@@ -153,7 +154,7 @@ def test_custom_texts(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase):
         '\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM)
     ]

-    for text in tests+more_tests:
+    for text in tests + more_tests:
         ids1 = model.tokenize(text, parse_special=True)
         ids2 = tokenizer.encode(text)
         logger.info(repr(text))
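The loop above is the comparison template all of the tests share: ids1 comes from the cffi-bound llama.cpp tokenizer, ids2 from the HuggingFace reference, and (as the raise Exception() context in the next hunk suggests) a mismatch is reported around the index from find_first_mismatch before the test aborts. A condensed sketch of that shape, with the reporting details assumed rather than taken from this diff:

    # Assumed shape of the per-text check (reporting details are illustrative).
    def check_text(model, tokenizer, text: str):
        ids1 = model.tokenize(text, parse_special=True)  # llama.cpp via cffi
        ids2 = tokenizer.encode(text)                    # HuggingFace reference
        if ids1 != ids2:
            i = find_first_mismatch(ids1, ids2)
            logger.error("TokenIDs: %s" % ids1[max(0, i - 2):i + 3])
            logger.error("Expected: %s" % ids2[max(0, i - 2):i + 3])
            raise Exception()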
@@ -164,9 +165,9 @@ def test_custom_texts(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase):
             raise Exception()


-def test_random_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
+def test_random_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations=100):

-    WHITESPACES = list(" "*20 + "\n"*5 + "\r\n"*5 + "\t"*5)
+    WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
     CHARS = list(set("""
         ABCDEFGHIJKLMNOPQRSTUVWXYZ
         abcdefghijklmnopqrstuvwxyz
@@ -179,7 +180,7 @@ def test_random_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, it
     rand = random.Random()
     for m in range(iterations):

-        logger.debug("%d/%d" % (m+1,iterations))
+        logger.debug("%d/%d" % (m + 1, iterations))
         rand.seed(m)

         text = []
@@ -188,7 +189,7 @@ def test_random_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, it
             k = rand.randint(1, 7)
             word = rand.choices(CHARS, k=k)
             space = rand.choice(WHITESPACES)
-            text.append("".join(word)+space)
+            text.append("".join(word) + space)
         text = "".join(text)

         ids1 = model.tokenize(text, parse_special=True)
@@ -196,7 +197,7 @@ def test_random_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, it
         assert(ids1 == ids2)


-def test_random_vocab_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
+def test_random_vocab_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations=100):

     logger.info("Building vocab char list ...")
     vocab_ids = list(tokenizer.vocab.values())
@@ -208,7 +209,7 @@ def test_random_vocab_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBa
     rand = random.Random()
     for m in range(iterations):

-        logger.debug("%d/%d" % (m+1,iterations))
+        logger.debug("%d/%d" % (m + 1, iterations))
         rand.seed(m)

         text = rand.choices(vocab_chars, k=1024)
@@ -219,7 +220,7 @@ def test_random_vocab_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBa
         assert(ids1 == ids2)


-def test_random_vocab_tokens(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
+def test_random_vocab_tokens(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations = 100):

     logger.info("Building token list ...")
     space_id = tokenizer.encode(" ")[0]
@@ -241,7 +242,7 @@ def test_random_vocab_tokens(model:LibLlamaModel, tokenizer:PreTrainedTokenizerB
     rand = random.Random()
     for m in range(iterations):

-        logger.debug("%d/%d" % (m+1,iterations))
+        logger.debug("%d/%d" % (m + 1, iterations))
         rand.seed(m)

         text = []
@@ -249,7 +250,7 @@ def test_random_vocab_tokens(model:LibLlamaModel, tokenizer:PreTrainedTokenizerB
         for i in range(num_words):
             k = rand.randint(1, 3)
             tokens = rand.choices(vocab_tokens, k=k)
-            tokens = [ t.strip(" \n\r\t") for t in tokens ]
+            tokens = [t.strip(" \n\r\t") for t in tokens]
             sep = rand.choice(" \n\r\t")
             text.append("".join(tokens) + sep)
         text = "".join(text)
@@ -259,15 +260,15 @@ def test_random_vocab_tokens(model:LibLlamaModel, tokenizer:PreTrainedTokenizerB
         assert(ids1 == ids2)


-def test_random_bytes(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
+def test_random_bytes(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations=100):

-    WHITESPACES = list(" "*20 + "\n"*5 + "\r\n"*5 + "\t"*5)
+    WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)

     logger.info("Bruteforce random bytes encodings ...")
     rand = random.Random()
     for m in range(iterations):

-        logger.debug("%d/%d" % (m+1,iterations))
+        logger.debug("%d/%d" % (m + 1, iterations))
         rand.seed(m)

         text = []
@@ -302,6 +303,6 @@ if __name__ == "__main__":
     test_random_chars(model, tokenizer, 10_000)
     test_random_vocab_chars(model, tokenizer, 10_000)
     test_random_vocab_tokens(model, tokenizer, 10_000)
-    #test_random_bytes(model, tokenizer, 10_000) # FAIL
+    # test_random_bytes(model, tokenizer, 10_000) # FAIL

     model.free()