Refactor dtype handling to be extensible

This code is equivalent to the previous version, but is now structured so
that more NumPy dtypes can be added easily.
Ondřej Čertík 2024-03-13 12:21:24 -06:00
parent 463628372d
commit b7e9d5c8d4


@@ -196,9 +196,6 @@ class GGUFWriter:
         if self.state is not WriterState.EMPTY:
             raise ValueError(f'Expected output file to be empty, got {self.state}')
 
-        if raw_dtype is None and tensor_dtype not in (np.float32, np.float16):
-            raise ValueError("Only F32 and F16 tensors are supported for now")
-
         encoded_name = name.encode("utf8")
         self.ti_data += self._pack("Q", len(encoded_name))
         self.ti_data += encoded_name
@@ -207,7 +204,12 @@ class GGUFWriter:
         for i in range(n_dims):
             self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
         if raw_dtype is None:
-            dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
+            if tensor_dtype == np.float32:
+                dtype = GGMLQuantizationType.F32
+            elif tensor_dtype == np.float16:
+                dtype = GGMLQuantizationType.F16
+            else:
+                raise ValueError("Only F32 and F16 tensors are supported for now")
         else:
             dtype = raw_dtype
         self.ti_data += self._pack("I", dtype)
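
With this structure, supporting another dtype is a single additional branch.
One possible way to take the refactor further is a lookup table, sketched
below; the _NUMPY_TO_GGML name and the dtype_to_ggml helper are illustrative
rather than part of this commit, and only the F32 and F16 members of
GGMLQuantizationType, which already exist in gguf-py, are assumed:

    import numpy as np

    from gguf import GGMLQuantizationType

    # Lookup table from NumPy dtype to GGML quantization type; adding
    # support for a new dtype becomes a one-line entry here instead of
    # another elif branch. (Names here are illustrative, not part of
    # the commit.)
    _NUMPY_TO_GGML = {
        np.dtype(np.float32): GGMLQuantizationType.F32,
        np.dtype(np.float16): GGMLQuantizationType.F16,
    }

    def dtype_to_ggml(tensor_dtype: np.dtype) -> GGMLQuantizationType:
        try:
            return _NUMPY_TO_GGML[np.dtype(tensor_dtype)]
        except KeyError:
            raise ValueError("Only F32 and F16 tensors are supported for now")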