fix: update torch version

This commit is contained in:
namtranase 2024-01-02 14:58:40 +07:00
parent 6c46cb1da4
commit 9e02214e98
2 changed files with 2 additions and 2 deletions

View file

@ -1,2 +1,2 @@
torch>=2.0.0 torch>=2.1.1
transformers>=4.32.0 transformers>=4.32.0

View file

@ -59,7 +59,7 @@ class Model:
from safetensors import safe_open from safetensors import safe_open
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu")) ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
else: else:
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", weights_only=True)) ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
with ctx as model_part: with ctx as model_part:
for name in model_part.keys(): for name in model_part.keys():