Update bruteforce test: header files location

This commit is contained in:
jaime-m-p 2024-07-04 22:35:21 +02:00
parent 2f150197e4
commit 11ac641c1e

View file

@ -24,17 +24,20 @@ logger = logging.getLogger("test-tokenizer-random")
class LibLlama:
DEFAULT_PATH_LLAMA_H = "./llama.h"
DEFAULT_PATH_LIBLLAMA = "./build/libllama.so" # CMakeLists.txt: BUILD_SHARED_LIBS ON
DEFAULT_PATH_LLAMA_H = "./include/llama.h"
DEFAULT_PATH_INCLUDES = [ "./ggml/include/", "./include/" ]
DEFAULT_PATH_LIBLLAMA = "./build/src/libllama.so" # CMakeLists.txt: BUILD_SHARED_LIBS ON
def __init__(self, path_llama_h: str = None, path_libllama: str = None):
path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H
def __init__(self, path_llama_h: str = None, path_includes: list[str] = [], path_libllama: str = None):
path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H
path_includes = path_includes or self.DEFAULT_PATH_INCLUDES
path_libllama = path_libllama or self.DEFAULT_PATH_LIBLLAMA
(self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_libllama)
(self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_includes, path_libllama)
self.lib.llama_backend_init()
def _load_libllama_cffi(self, path_llama_h: str, path_libllama: str):
cmd = ["gcc", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)=", path_llama_h]
def _load_libllama_cffi(self, path_llama_h: str, path_includes: list[str], path_libllama: str):
cmd = ["gcc", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)="]
cmd += ["-I" + path for path in path_includes] + [path_llama_h]
res = subprocess.run(cmd, stdout=subprocess.PIPE)
assert (res.returncode == 0)
source = res.stdout.decode()