Use mmap in torch load, prefer .bin files when loading

This commit is contained in:
Galunid 2023-11-24 15:19:46 +01:00
parent 8e672efe63
commit b9276f3b59

View file

@ -41,8 +41,8 @@ class Model:
self.fname_out = fname_out self.fname_out = fname_out
self.is_big_endian = is_big_endian self.is_big_endian = is_big_endian
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.is_safetensors = self._is_model_safetensors() self.is_pytorch = self._is_model_pytorch()
self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin") self.num_parts = Model.count_model_parts(self.dir_model, ".bin" if self.is_pytorch else ".safetensors")
self.part_names = self._get_part_names() self.part_names = self._get_part_names()
self.hparams = Model.load_hparams(self.dir_model) self.hparams = Model.load_hparams(self.dir_model)
self.model_arch = self._get_model_architecture() self.model_arch = self._get_model_architecture()
@ -55,15 +55,15 @@ class Model:
for part_name in self.part_names: for part_name in self.part_names:
print(f"gguf: loading model part '{part_name}'") print(f"gguf: loading model part '{part_name}'")
ctx: ContextManager[Any] ctx: ContextManager[Any]
if self.is_safetensors: if self.is_pytorch:
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
else:
from safetensors import safe_open from safetensors import safe_open
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu")) ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
else:
ctx = contextlib.nullcontext(torch.load(self.dir_model / part_name, map_location="cpu"))
with ctx as model_part: with ctx as model_part:
for name in model_part.keys(): for name in model_part.keys():
data = model_part.get_tensor(name) if self.is_safetensors else model_part[name] data = model_part[name] if self.is_pytorch else model_part.get_tensor(name)
yield name, data yield name, data
def set_gguf_parameters(self): def set_gguf_parameters(self):
@ -170,19 +170,19 @@ class Model:
return StableLMModel return StableLMModel
return Model return Model
def _is_model_safetensors(self) -> bool: def _is_model_pytorch(self) -> bool:
return Model.count_model_parts(self.dir_model, ".safetensors") > 0 return Model.count_model_parts(self.dir_model, ".bin") > 0
def _get_part_names(self): def _get_part_names(self):
if self.is_safetensors: if self.is_pytorch:
if self.num_parts == 1: # there's only one .safetensors file
return ("model.safetensors",)
return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1))
if self.num_parts == 1: # there's only one .bin file if self.num_parts == 1: # there's only one .bin file
return ("pytorch_model.bin",) return ("pytorch_model.bin",)
return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1)) return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
if self.num_parts == 1: # there's only one .safetensors file
return ("model.safetensors",)
return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1))
def _get_model_architecture(self) -> gguf.MODEL_ARCH: def _get_model_architecture(self) -> gguf.MODEL_ARCH:
arch = self.hparams["architectures"][0] arch = self.hparams["architectures"][0]
if arch == "GPTNeoXForCausalLM": if arch == "GPTNeoXForCausalLM":