From b9276f3b59a5813d432cde53fdb6c9fcb3450122 Mon Sep 17 00:00:00 2001
From: Galunid
Date: Fri, 24 Nov 2023 15:19:46 +0100
Subject: [PATCH] Use mmap in torch load, prefer .bin files when loading

---
 convert-hf-to-gguf.py | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 1105670c1..9dc8140a9 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -41,8 +41,8 @@ class Model:
         self.fname_out = fname_out
         self.is_big_endian = is_big_endian
         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
-        self.is_safetensors = self._is_model_safetensors()
-        self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin")
+        self.is_pytorch = self._is_model_pytorch()
+        self.num_parts = Model.count_model_parts(self.dir_model, ".bin" if self.is_pytorch else ".safetensors")
         self.part_names = self._get_part_names()
         self.hparams = Model.load_hparams(self.dir_model)
         self.model_arch = self._get_model_architecture()
@@ -55,15 +55,15 @@ class Model:
         for part_name in self.part_names:
             print(f"gguf: loading model part '{part_name}'")
             ctx: ContextManager[Any]
-            if self.is_safetensors:
+            if self.is_pytorch:
+                ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
+            else:
                 from safetensors import safe_open
                 ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
-            else:
-                ctx = contextlib.nullcontext(torch.load(self.dir_model / part_name, map_location="cpu"))

             with ctx as model_part:
                 for name in model_part.keys():
-                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
+                    data = model_part[name] if self.is_pytorch else model_part.get_tensor(name)
                     yield name, data

     def set_gguf_parameters(self):
@@ -170,18 +170,18 @@ class Model:
             return StableLMModel
         return Model

-    def _is_model_safetensors(self) -> bool:
-        return Model.count_model_parts(self.dir_model, ".safetensors") > 0
+    def _is_model_pytorch(self) -> bool:
+        return Model.count_model_parts(self.dir_model, ".bin") > 0

     def _get_part_names(self):
-        if self.is_safetensors:
-            if self.num_parts == 1: # there's only one .safetensors file
-                return ("model.safetensors",)
-            return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1))
+        if self.is_pytorch:
+            if self.num_parts == 1: # there's only one .bin file
+                return ("pytorch_model.bin",)
+            return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))

-        if self.num_parts == 1: # there's only one .bin file
-            return ("pytorch_model.bin",)
-        return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
+        if self.num_parts == 1: # there's only one .safetensors file
+            return ("model.safetensors",)
+        return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1))

     def _get_model_architecture(self) -> gguf.MODEL_ARCH:
         arch = self.hparams["architectures"][0]
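
Note (not part of the patch): the core change is the new PyTorch loading path, where torch.load(..., mmap=True, weights_only=True) memory-maps the .bin checkpoint instead of reading it fully into RAM. A minimal standalone sketch of that pattern follows, assuming PyTorch >= 2.1 (which introduced the mmap argument) and the safetensors package; iter_tensors and the example path are hypothetical names for illustration, not identifiers from the patch.

    # Sketch of the loading pattern the diff switches to (assumptions noted above).
    import contextlib
    from pathlib import Path

    import torch
    from safetensors import safe_open

    def iter_tensors(part_path: Path, is_pytorch: bool):
        if is_pytorch:
            # mmap=True maps the checkpoint file lazily instead of loading it all into RAM;
            # weights_only=True restricts unpickling to plain tensor/state-dict data.
            ctx = contextlib.nullcontext(
                torch.load(str(part_path), map_location="cpu", mmap=True, weights_only=True)
            )
        else:
            # safetensors files are opened as a context manager and read tensor-by-tensor.
            ctx = safe_open(part_path, framework="pt", device="cpu")
        with ctx as model_part:
            for name in model_part.keys():
                yield name, (model_part[name] if is_pytorch else model_part.get_tensor(name))

    # Hypothetical usage:
    # for name, tensor in iter_tensors(Path("pytorch_model.bin"), is_pytorch=True):
    #     print(name, tuple(tensor.shape))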