Merge 6f9d1275a0
into 6026da52d6
This commit is contained in:
commit
b276c09b6f
1 changed files with 11 additions and 0 deletions
|
@ -164,8 +164,19 @@ class Model:
|
|||
for name in model_part.keys():
|
||||
if self.is_safetensors:
|
||||
if self.lazy:
|
||||
if (name.endswith("_scale") and name.removesuffix("_scale") in model_part.keys()):
|
||||
continue
|
||||
data = model_part.get_slice(name)
|
||||
data = LazyTorchTensor.from_safetensors_slice(data)
|
||||
if (name + "_scale" in model_part.keys()):
|
||||
orig_shape = data.shape
|
||||
scale = model_part.get_slice(name + "_scale")
|
||||
shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape)))))
|
||||
data = data.unsqueeze(0).expand((4, *orig_shape)) >> shift
|
||||
data = data & 3
|
||||
data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:]))
|
||||
# The scale is inverted
|
||||
data = data / LazyTorchTensor.from_safetensors_slice(scale).float()
|
||||
else:
|
||||
data = model_part.get_tensor(name)
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue