commit a5fa2fec60
parent edf375d26f
Author: jaime-m-p
Date:   2024-05-07 20:38:00 +02:00


@@ -21,7 +21,7 @@ class LibLlama:
     DEFAULT_PATH_LLAMA_H = "./llama.h"
     DEFAULT_PATH_LIBLLAMA = "./build/libllama.so" # CMakeLists.txt: BUILD_SHARED_LIBS ON

-    def __init__(self, path_llama_h:str=None, path_libllama:str=None):
+    def __init__(self, path_llama_h: str = None, path_libllama: str = None):
         path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H
         path_libllama = path_libllama or self.DEFAULT_PATH_LIBLLAMA
         (self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_libllama)
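
Note: LibLlama wraps the shared library through cffi: it reads the declarations in llama.h with ffi.cdef() and loads libllama.so with ffi.dlopen(). A minimal usage sketch (hypothetical, assuming llama.cpp was built with BUILD_SHARED_LIBS=ON so the default paths exist):

    # Hypothetical usage sketch, not part of this diff.
    libllama = LibLlama()  # parses ./llama.h, then dlopen()s ./build/libllama.so
    # or with explicit paths:
    libllama = LibLlama(path_llama_h="./llama.h", path_libllama="./build/libllama.so")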
@@ -42,34 +42,35 @@ class LibLlama:
         ffi.cdef(source, override=True)
         lib = ffi.dlopen(path_libllama)
         return (ffi, lib)

     def model_default_params(self, **kwargs):
         mparams = self.lib.llama_model_default_params()
         for k, v in kwargs.items():
             setattr(mparams, k, v)
         return mparams

     def context_default_params(self, **kwargs):
         cparams = self.lib.llama_context_default_params()
         for k, v in kwargs.items():
             setattr(cparams, k, v)
         return cparams


 class LibLlamaModel:

-    def __init__(self, libllama:LibLlama, path_model:str, mparams={}, cparams={}):
+    def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}):
         self.lib = libllama.lib
         self.ffi = libllama.ffi
         if type(mparams) == dict:
             mparams = libllama.model_default_params(**mparams)
         self.model = self.lib.llama_load_model_from_file(path_model.encode(), mparams)
         if not self.model:
-            raise RuntimeError("error: failed to load model '%s'"%path_model)
+            raise RuntimeError("error: failed to load model '%s'" % path_model)
         if type(cparams) == dict:
             cparams = libllama.context_default_params(**cparams)
         self.ctx = self.lib.llama_new_context_with_model(self.model, cparams)
         if not self.ctx:
-            raise RuntimeError("error: failed to create context for model '%s'"%path_model)
+            raise RuntimeError("error: failed to create context for model '%s'" % path_model)
         n_tokens_max = self.lib.llama_n_ctx(self.ctx)
         self.token_ids = self.ffi.new("llama_token[]", n_tokens_max)
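
Since cffi struct fields support setattr, the **kwargs in model_default_params() / context_default_params() let callers override individual fields of the C parameter structs, and LibLlamaModel.__init__ forwards plain dicts through the same path. A hedged sketch (the model path is hypothetical; vocab_only and n_ctx are fields of llama_model_params / llama_context_params in llama.cpp at this time):

    # Hypothetical sketch, not part of this diff.
    model = LibLlamaModel(libllama, "path/to/model.gguf",
                          mparams={"vocab_only": True},  # load only the vocab, skip tensors
                          cparams={"n_ctx": 2048})       # context size for llama_new_context_with_model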
@@ -82,7 +83,7 @@ class LibLlamaModel:
         self.model = None
         self.lib = None

-    def tokenize(self, text:str, n_tokens_max:int=0, add_special:bool=False, parse_special:bool=False) -> list[int]:
+    def tokenize(self, text: str, n_tokens_max: int = 0, add_special: bool = False, parse_special: bool = False) -> list[int]:
         n_tokens_max = n_tokens_max if n_tokens_max > 0 else len(self.token_ids)
         text = text.encode("utf-8")
         num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, n_tokens_max, add_special, parse_special)
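
tokenize() UTF-8 encodes the text, calls llama_tokenize() into the preallocated token_ids buffer (sized to llama_n_ctx() in __init__), and returns the ids as a Python list. A hypothetical call:

    # Hypothetical usage sketch, not part of this diff.
    ids = model.tokenize("Hello world", parse_special=True)  # -> list of token ids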
@@ -91,14 +92,14 @@ class LibLlamaModel:
         return list(self.token_ids[0:num])


-def find_first_mismatch(ids1:list[int], ids2:list[int]):
+def find_first_mismatch(ids1: list[int], ids2: list[int]):
     for i, (a,b) in enumerate(zip(ids1, ids2)):
         if a != b:
             return i
     return -1 if len(ids1) == len(ids2) else i


-def test_custom_texts(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase):
+def test_custom_texts(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase):

     tests = [
         "",
@@ -153,7 +154,7 @@ def test_custom_texts(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase):
         '\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM)
     ]

-    for text in tests+more_tests:
+    for text in tests + more_tests:
         ids1 = model.tokenize(text, parse_special=True)
         ids2 = tokenizer.encode(text)
         logger.info(repr(text))
@@ -164,9 +165,9 @@ def test_custom_texts(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase):
             raise Exception()


-def test_random_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
+def test_random_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations=100):

-    WHITESPACES = list(" "*20 + "\n"*5 + "\r\n"*5 + "\t"*5)
+    WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
     CHARS = list(set("""
         ABCDEFGHIJKLMNOPQRSTUVWXYZ
         abcdefghijklmnopqrstuvwxyz
@@ -174,12 +175,12 @@ def test_random_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
         áéíóúàèìòùâêîôûäëïöü
         .-,*/-+ª!"·$%&/()=?¿[]{}<>\\|@#~½¬~;:_
     """))

     logger.info("Bruteforce random chars encodings ...")
     rand = random.Random()
     for m in range(iterations):
-        logger.debug("%d/%d" % (m+1,iterations))
+        logger.debug("%d/%d" % (m + 1, iterations))
         rand.seed(m)

         text = []
@@ -188,7 +189,7 @@ def test_random_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
             k = rand.randint(1, 7)
             word = rand.choices(CHARS, k=k)
             space = rand.choice(WHITESPACES)
-            text.append("".join(word)+space)
+            text.append("".join(word) + space)
         text = "".join(text)

         ids1 = model.tokenize(text, parse_special=True)
@@ -196,21 +197,21 @@ def test_random_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
         assert(ids1 == ids2)


-def test_random_vocab_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
+def test_random_vocab_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations=100):

     logger.info("Building vocab char list ...")
     vocab_ids = list(tokenizer.vocab.values())
     vocab_text = tokenizer.decode(vocab_ids)
     vocab_chars = list(set(vocab_text))
     del vocab_ids, vocab_text

     logger.info("Bruteforce random text encodings ...")
     rand = random.Random()
     for m in range(iterations):
-        logger.debug("%d/%d" % (m+1,iterations))
+        logger.debug("%d/%d" % (m + 1, iterations))
         rand.seed(m)
         text = rand.choices(vocab_chars, k=1024)
         text = "".join(text)
@@ -219,7 +220,7 @@ def test_random_vocab_chars(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
         assert(ids1 == ids2)


-def test_random_vocab_tokens(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
+def test_random_vocab_tokens(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations = 100):

     logger.info("Building token list ...")
     space_id = tokenizer.encode(" ")[0]
@@ -230,7 +231,7 @@ def test_random_vocab_tokens(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
     vocab_tokens = tokenizer.decode(vocab_ids)
     vocab_tokens = vocab_tokens.split(" ")
     del vocab_ids

     logger.info("Checking single token encodings ...")
     for token in vocab_tokens:
         ids1 = model.tokenize(token, parse_special=True)
@@ -241,15 +242,15 @@ def test_random_vocab_tokens(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
     rand = random.Random()
     for m in range(iterations):
-        logger.debug("%d/%d" % (m+1,iterations))
+        logger.debug("%d/%d" % (m + 1, iterations))
         rand.seed(m)

         text = []
         num_words = rand.randint(300, 400)
         for i in range(num_words):
             k = rand.randint(1, 3)
             tokens = rand.choices(vocab_tokens, k=k)
-            tokens = [ t.strip(" \n\r\t") for t in tokens ]
+            tokens = [t.strip(" \n\r\t") for t in tokens]
             sep = rand.choice(" \n\r\t")
             text.append("".join(tokens) + sep)
         text = "".join(text)
@@ -259,15 +260,15 @@ def test_random_vocab_tokens(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
         assert(ids1 == ids2)


-def test_random_bytes(model:LibLlamaModel, tokenizer:PreTrainedTokenizerBase, iterations=100):
+def test_random_bytes(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations=100):

-    WHITESPACES = list(" "*20 + "\n"*5 + "\r\n"*5 + "\t"*5)
+    WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)

     logger.info("Bruteforce random bytes encodings ...")
     rand = random.Random()
     for m in range(iterations):
-        logger.debug("%d/%d" % (m+1,iterations))
+        logger.debug("%d/%d" % (m + 1, iterations))
         rand.seed(m)

         text = []
@@ -302,6 +303,6 @@ if __name__ == "__main__":
     test_random_chars(model, tokenizer, 10_000)
     test_random_vocab_chars(model, tokenizer, 10_000)
     test_random_vocab_tokens(model, tokenizer, 10_000)
-    #test_random_bytes(model, tokenizer, 10_000) # FAIL
+    # test_random_bytes(model, tokenizer, 10_000) # FAIL
     model.free()
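
The entry point above cross-checks the llama.cpp tokenizer against a Hugging Face tokenizer over all the brute-force suites (the failing random-bytes test stays commented out). A hedged sketch of driving the same comparison manually (AutoTokenizer and the paths are assumptions, not shown in this diff):

    # Hypothetical sketch, not part of this diff.
    from transformers import AutoTokenizer
    libllama = LibLlama()
    model = LibLlamaModel(libllama, "path/to/model.gguf", cparams={"n_ctx": 2048})
    tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")
    test_custom_texts(model, tokenizer)
    model.free()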