Convert unsupported datatypes to f32 when converting BERT architectures to GGUF

Christian Azinn 2024-04-26 18:26:21 -04:00
parent 928e0b7013
commit 5ae78a1f23


@@ -2482,6 +2482,10 @@ class BertModel(Model):
                 print(f"Can not map tensor {name!r}")
                 sys.exit()
 
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
             data = data_torch.squeeze().numpy()
             n_dims = len(data.shape)
             new_dtype: type[np.floating[Any]]
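
For context, a minimal self-contained sketch (not part of the converter; the tensor values and shape are illustrative) of why this cast is needed: BERT checkpoints are often published in bfloat16, a dtype NumPy has no equivalent for, so calling .numpy() on such a tensor raises a TypeError unless it is first cast to a supported float type.

# Sketch: why the f32 cast is needed before .numpy()
import torch

# stand-in for a checkpoint tensor stored in bfloat16
data_torch = torch.randn(4, 4).to(torch.bfloat16)

# without a cast, this line would raise:
#   TypeError: Got unsupported ScalarType BFloat16
# data = data_torch.squeeze().numpy()

# mirror the commit's guard: anything that is not f16/f32 is cast up to f32
if data_torch.dtype not in (torch.float16, torch.float32):
    data_torch = data_torch.to(torch.float32)

data = data_torch.squeeze().numpy()  # now a plain float32 ndarray
print(data.dtype)  # float32

Casting up to float32, rather than down to float16, presumably avoids losing precision before the converter's existing per-tensor logic (the new_dtype selection that follows) decides the final output dtype.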