Refactor dtype handling to be extensible
This code behaves the same as before, but it is now structured to easily support more NumPy dtypes (a sketch of such an extension follows the diff below).
parent 463628372d
commit b7e9d5c8d4
1 changed file with 6 additions and 4 deletions
@@ -196,9 +196,6 @@ class GGUFWriter:
         if self.state is not WriterState.EMPTY:
             raise ValueError(f'Expected output file to be empty, got {self.state}')
 
-        if raw_dtype is None and tensor_dtype not in (np.float32, np.float16):
-            raise ValueError("Only F32 and F16 tensors are supported for now")
-
         encoded_name = name.encode("utf8")
         self.ti_data += self._pack("Q", len(encoded_name))
         self.ti_data += encoded_name
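For readers unfamiliar with the helper, `_pack` writes fixed-width binary fields into the tensor-info blob. A minimal standalone equivalent is sketched here; the little-endian `<` prefix is an assumption based on GGUF's default byte order, not something shown in this diff:

import struct

def _pack(fmt: str, value) -> bytes:
    # e.g. _pack("Q", 3) -> 8-byte unsigned integer,
    #      _pack("I", 0) -> 4-byte unsigned integer
    # Little-endian is assumed here for illustration.
    return struct.pack("<" + fmt, value)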
@@ -207,7 +204,12 @@ class GGUFWriter:
         for i in range(n_dims):
             self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
         if raw_dtype is None:
-            dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
+            if tensor_dtype == np.float32:
+                dtype = GGMLQuantizationType.F32
+            elif tensor_dtype == np.float16:
+                dtype = GGMLQuantizationType.F16
+            else:
+                raise ValueError("Only F32 and F16 tensors are supported for now")
         else:
             dtype = raw_dtype
         self.ti_data += self._pack("I", dtype)
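To make the extensibility claim in the commit message concrete, here is a minimal standalone sketch of the refactored selection logic with one hypothetical extra dtype. The `resolve_dtype` helper, the `gguf` import path, and the commented-out `I8` branch are illustrative assumptions, not part of this commit:

import numpy as np
from gguf import GGMLQuantizationType  # import path assumed for the sketch

def resolve_dtype(tensor_dtype, raw_dtype=None):
    # An explicitly supplied raw_dtype takes precedence, as in the diff above.
    if raw_dtype is not None:
        return raw_dtype
    if tensor_dtype == np.float32:
        return GGMLQuantizationType.F32
    elif tensor_dtype == np.float16:
        return GGMLQuantizationType.F16
    # Supporting another NumPy dtype is now a single extra branch, e.g.:
    # elif tensor_dtype == np.int8:
    #     return GGMLQuantizationType.I8  # hypothetical enum member
    else:
        raise ValueError("Only F32 and F16 tensors are supported for now")

Compared with the old one-line ternary, the if/elif chain keeps each dtype mapping on its own branch, so new dtypes can be added without touching the existing ones.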