From 0f73d87dbe51853c01fb88e9cb3126e5f91424a7 Mon Sep 17 00:00:00 2001 From: Galunid Date: Fri, 24 Nov 2023 16:12:47 +0100 Subject: [PATCH] Revert .bin > .safetensors preference --- convert-hf-to-gguf.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 9dc8140a9..d223d5190 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -41,8 +41,8 @@ class Model: self.fname_out = fname_out self.is_big_endian = is_big_endian self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE - self.is_pytorch = self._is_model_pytorch() - self.num_parts = Model.count_model_parts(self.dir_model, ".bin" if self.is_pytorch else ".safetensors") + self.is_safetensors = self._is_model_safetensors() + self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin") self.part_names = self._get_part_names() self.hparams = Model.load_hparams(self.dir_model) self.model_arch = self._get_model_architecture() @@ -55,15 +55,15 @@ class Model: for part_name in self.part_names: print(f"gguf: loading model part '{part_name}'") ctx: ContextManager[Any] - if self.is_pytorch: - ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True)) - else: + if self.is_safetensors: from safetensors import safe_open ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu")) + else: + ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True)) with ctx as model_part: for name in model_part.keys(): - data = model_part[name] if self.is_pytorch else model_part.get_tensor(name) + data = model_part.get_tensor(name) if self.is_safetensors else model_part[name] yield name, data def set_gguf_parameters(self): @@ -170,18 +170,18 @@ class Model: return StableLMModel return Model - def _is_model_pytorch(self) -> bool: - return Model.count_model_parts(self.dir_model, ".bin") > 0 + def _is_model_safetensors(self) -> bool: + return Model.count_model_parts(self.dir_model, ".safetensors") > 0 def _get_part_names(self): - if self.is_pytorch: - if self.num_parts == 1: # there's only one .bin file - return ("pytorch_model.bin",) - return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1)) + if self.is_safetensors: + if self.num_parts == 1: # there's only one .safetensors file + return ("model.safetensors",) + return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1)) - if self.num_parts == 1: # there's only one .safetensors file - return ("model.safetensors",) - return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1)) + if self.num_parts == 1: # there's only one .bin file + return ("pytorch_model.bin",) + return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1)) def _get_model_architecture(self) -> gguf.MODEL_ARCH: arch = self.hparams["architectures"][0]