Revert .bin > .safetensors preference
This commit is contained in:
parent
b9276f3b59
commit
0f73d87dbe
1 changed file with 15 additions and 15 deletions
|
@@ -41,8 +41,8 @@ class Model:
|
||||||
self.fname_out = fname_out
|
self.fname_out = fname_out
|
||||||
self.is_big_endian = is_big_endian
|
self.is_big_endian = is_big_endian
|
||||||
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
|
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
|
||||||
self.is_pytorch = self._is_model_pytorch()
|
self.is_safetensors = self._is_model_safetensors()
|
||||||
self.num_parts = Model.count_model_parts(self.dir_model, ".bin" if self.is_pytorch else ".safetensors")
|
self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin")
|
||||||
self.part_names = self._get_part_names()
|
self.part_names = self._get_part_names()
|
||||||
self.hparams = Model.load_hparams(self.dir_model)
|
self.hparams = Model.load_hparams(self.dir_model)
|
||||||
self.model_arch = self._get_model_architecture()
|
self.model_arch = self._get_model_architecture()
|
||||||
|
@@ -55,15 +55,15 @@ class Model:
|
||||||
for part_name in self.part_names:
|
for part_name in self.part_names:
|
||||||
print(f"gguf: loading model part '{part_name}'")
|
print(f"gguf: loading model part '{part_name}'")
|
||||||
ctx: ContextManager[Any]
|
ctx: ContextManager[Any]
|
||||||
if self.is_pytorch:
|
if self.is_safetensors:
|
||||||
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
|
|
||||||
else:
|
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
|
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
|
||||||
|
else:
|
||||||
|
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
|
||||||
|
|
||||||
with ctx as model_part:
|
with ctx as model_part:
|
||||||
for name in model_part.keys():
|
for name in model_part.keys():
|
||||||
data = model_part[name] if self.is_pytorch else model_part.get_tensor(name)
|
data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
|
||||||
yield name, data
|
yield name, data
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
|
@@ -170,19 +170,19 @@ class Model:
|
||||||
return StableLMModel
|
return StableLMModel
|
||||||
return Model
|
return Model
|
||||||
|
|
||||||
def _is_model_pytorch(self) -> bool:
|
def _is_model_safetensors(self) -> bool:
|
||||||
return Model.count_model_parts(self.dir_model, ".bin") > 0
|
return Model.count_model_parts(self.dir_model, ".safetensors") > 0
|
||||||
|
|
||||||
def _get_part_names(self):
|
def _get_part_names(self):
|
||||||
if self.is_pytorch:
|
if self.is_safetensors:
|
||||||
if self.num_parts == 1: # there's only one .bin file
|
|
||||||
return ("pytorch_model.bin",)
|
|
||||||
return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
|
|
||||||
|
|
||||||
if self.num_parts == 1: # there's only one .safetensors file
|
if self.num_parts == 1: # there's only one .safetensors file
|
||||||
return ("model.safetensors",)
|
return ("model.safetensors",)
|
||||||
return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1))
|
return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1))
|
||||||
|
|
||||||
|
if self.num_parts == 1: # there's only one .bin file
|
||||||
|
return ("pytorch_model.bin",)
|
||||||
|
return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
|
||||||
|
|
||||||
def _get_model_architecture(self) -> gguf.MODEL_ARCH:
|
def _get_model_architecture(self) -> gguf.MODEL_ARCH:
|
||||||
arch = self.hparams["architectures"][0]
|
arch = self.hparams["architectures"][0]
|
||||||
if arch == "GPTNeoXForCausalLM":
|
if arch == "GPTNeoXForCausalLM":
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue