consolidated.safetensors

easier handling (e.g. for Ministral)
This commit is contained in:
CrispStrobe 2024-10-16 23:18:56 +02:00 committed by GitHub
parent 9e04102448
commit 645eb3c6ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -83,10 +83,10 @@ class Model:
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.use_temp_file = use_temp_file self.use_temp_file = use_temp_file
self.lazy = not eager self.lazy = not eager
self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors") self.part_names = Model.get_model_part_names(self.dir_model, ["model"], [".safetensors"])
self.is_safetensors = len(self.part_names) > 0 self.is_safetensors = len(self.part_names) > 0
if not self.is_safetensors: if not self.is_safetensors:
self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin") self.part_names = Model.get_model_part_names(self.dir_model, ["pytorch_model"], [".bin"])
self.hparams = Model.load_hparams(self.dir_model) self.hparams = Model.load_hparams(self.dir_model)
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
@ -447,10 +447,12 @@ class Model:
self.gguf_writer.close() self.gguf_writer.close()
@staticmethod @staticmethod
def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]: def get_model_part_names(dir_model: Path, prefixes: list[str], suffixes: list[str]) -> list[str]:
part_names: list[str] = [] part_names: list[str] = []
for filename in os.listdir(dir_model): for filename in os.listdir(dir_model):
if filename.startswith(prefix) and filename.endswith(suffix): if any(filename.startswith(prefix) for prefix in prefixes) and any(filename.endswith(suffix) for suffix in suffixes):
part_names.append(filename)
elif filename == "consolidated.safetensors":
part_names.append(filename) part_names.append(filename)
part_names.sort() part_names.sort()