fix: update torch version
This commit is contained in:
parent
6c46cb1da4
commit
9e02214e98
2 changed files with 2 additions and 2 deletions
|
@ -1,2 +1,2 @@
|
||||||
torch>=2.0.0
|
torch>=2.1.1
|
||||||
transformers>=4.32.0
|
transformers>=4.32.0
|
||||||
|
|
|
@ -59,7 +59,7 @@ class Model:
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
|
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
|
||||||
else:
|
else:
|
||||||
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", weights_only=True))
|
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
|
||||||
|
|
||||||
with ctx as model_part:
|
with ctx as model_part:
|
||||||
for name in model_part.keys():
|
for name in model_part.keys():
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue