gguf-py : always defer GGUFWriter output file opening

Changing what happens when the output file is opened will be easier,
since this reduces the cases to consider.

* gguf-py : prevent GGUFWriter from writing all tensors multiple times

Writing the tensors more than once was already caught by an assertion,
but checking WriterState instead should make the error message slightly
less cryptic.
This commit is contained in:
Francis Couture-Harpin 2024-06-08 12:29:15 -04:00
parent fe59f20d26
commit 32d11dbbe8

View file

@ -51,10 +51,12 @@ class WriterState(Enum):
HEADER = auto()
KV_DATA = auto()
TI_DATA = auto()
WEIGHTS = auto()
class GGUFWriter:
fout: BufferedWriter | None
path: os.PathLike[str] | str | None
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
tensors: dict[str, TensorInfo]
kv_data: dict[str, GGUFValue]
@ -77,7 +79,8 @@ class GGUFWriter:
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False,
endianess: GGUFEndian = GGUFEndian.LITTLE,
):
self.fout = open(path, "wb") if path is not None else None
self.fout = None
self.path = path
self.arch = arch
self.endianess = endianess
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
@ -88,19 +91,29 @@ class GGUFWriter:
logger.info("gguf: This GGUF file is for {0} Endian only".format(
"Big" if self.endianess == GGUFEndian.BIG else "Little",
))
self.state = WriterState.NO_FILE if self.fout is None else WriterState.EMPTY
self.state = WriterState.NO_FILE
self.add_architecture()
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
# NOTE: not checking for WriterState.NO_FILE,
# because writing can technically be started over from any state,
# as long as a new path is provided
def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None:
if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
# allow calling this multiple times as long as the path is the same
return
if self.state is not WriterState.NO_FILE:
raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
if path is not None:
self.path = path
if self.path is not None:
if self.fout is not None:
self.fout.close()
self.fout = open(path, "wb")
self.fout = open(self.path, "wb")
self.state = WriterState.EMPTY
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
self.open_output_file(path)
if self.state is not WriterState.EMPTY:
raise ValueError(f'Expected output file to be empty, got {self.state}')
@ -206,8 +219,8 @@ class GGUFWriter:
self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
) -> None:
if self.state is not WriterState.EMPTY and self.state is not WriterState.NO_FILE:
raise ValueError(f'Expected output file to be empty or absent, got {self.state}')
if self.state is not WriterState.NO_FILE:
raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
if name in self.tensors:
raise ValueError(f'Duplicated tensor name {name!r}')
@ -263,8 +276,8 @@ class GGUFWriter:
fp.write(bytes([0] * pad))
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
if self.state is not WriterState.TI_DATA:
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
assert self.fout is not None
if self.endianess == GGUFEndian.BIG:
@ -273,6 +286,8 @@ class GGUFWriter:
tensor.tofile(self.fout)
self.write_padding(self.fout, tensor.nbytes)
self.state = WriterState.WEIGHTS
def write_tensors_to_file(self, *, progress: bool = False) -> None:
self.write_ti_data_to_file()
@ -299,14 +314,14 @@ class GGUFWriter:
bar.update(ti.nbytes)
self.write_padding(self.fout, ti.nbytes)
ti.tensor = None
else:
self.temp_file.seek(0)
return
shutil.copyfileobj(self.temp_file, self.fout)
self.flush()
self.temp_file.close()
self.temp_file.seek(0)
shutil.copyfileobj(self.temp_file, self.fout)
self.flush()
self.temp_file.close()
self.state = WriterState.WEIGHTS
def flush(self) -> None:
assert self.fout is not None