This commit is contained in:
Bruno Pio 2024-09-19 10:45:01 +01:00 committed by GitHub
commit b276c09b6f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -164,8 +164,19 @@ class Model:
for name in model_part.keys():
if self.is_safetensors:
if self.lazy:
if (name.endswith("_scale") and name.removesuffix("_scale") in model_part.keys()):
continue
data = model_part.get_slice(name)
data = LazyTorchTensor.from_safetensors_slice(data)
if (name + "_scale" in model_part.keys()):
orig_shape = data.shape
scale = model_part.get_slice(name + "_scale")
shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape)))))
data = data.unsqueeze(0).expand((4, *orig_shape)) >> shift
data = data & 3
data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:]))
# The scale is inverted
data = data / LazyTorchTensor.from_safetensors_slice(scale).float()
else:
data = model_part.get_tensor(name)
else: