gguf-py : always defer GGUFWrite output file opening
Changing what happens when the output file is opened will be easier, since this reduces the cases to consider. * gguf-py : prevent GGUFWriter from writing all tensors multiple times It was already checked with an assertion before, but using WriterState should make the error message slightly less cryptic.
This commit is contained in:
parent
fe59f20d26
commit
32d11dbbe8
1 changed files with 32 additions and 17 deletions
|
@ -51,10 +51,12 @@ class WriterState(Enum):
|
|||
HEADER = auto()
|
||||
KV_DATA = auto()
|
||||
TI_DATA = auto()
|
||||
WEIGHTS = auto()
|
||||
|
||||
|
||||
class GGUFWriter:
|
||||
fout: BufferedWriter | None
|
||||
path: os.PathLike[str] | str | None
|
||||
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
|
||||
tensors: dict[str, TensorInfo]
|
||||
kv_data: dict[str, GGUFValue]
|
||||
|
@ -77,7 +79,8 @@ class GGUFWriter:
|
|||
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False,
|
||||
endianess: GGUFEndian = GGUFEndian.LITTLE,
|
||||
):
|
||||
self.fout = open(path, "wb") if path is not None else None
|
||||
self.fout = None
|
||||
self.path = path
|
||||
self.arch = arch
|
||||
self.endianess = endianess
|
||||
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
|
||||
|
@ -88,19 +91,29 @@ class GGUFWriter:
|
|||
logger.info("gguf: This GGUF file is for {0} Endian only".format(
|
||||
"Big" if self.endianess == GGUFEndian.BIG else "Little",
|
||||
))
|
||||
self.state = WriterState.NO_FILE if self.fout is None else WriterState.EMPTY
|
||||
self.state = WriterState.NO_FILE
|
||||
|
||||
self.add_architecture()
|
||||
|
||||
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
|
||||
# NOTE: not checking for WriterState.NO_FILE,
|
||||
# because writing can technically be started over from any state,
|
||||
# as long as a new path is provided
|
||||
def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None:
|
||||
if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
|
||||
# allow calling this multiple times as long as the path is the same
|
||||
return
|
||||
if self.state is not WriterState.NO_FILE:
|
||||
raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
|
||||
|
||||
if path is not None:
|
||||
self.path = path
|
||||
|
||||
if self.path is not None:
|
||||
if self.fout is not None:
|
||||
self.fout.close()
|
||||
self.fout = open(path, "wb")
|
||||
self.fout = open(self.path, "wb")
|
||||
self.state = WriterState.EMPTY
|
||||
|
||||
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
|
||||
self.open_output_file(path)
|
||||
|
||||
if self.state is not WriterState.EMPTY:
|
||||
raise ValueError(f'Expected output file to be empty, got {self.state}')
|
||||
|
||||
|
@ -206,8 +219,8 @@ class GGUFWriter:
|
|||
self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
|
||||
tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
|
||||
) -> None:
|
||||
if self.state is not WriterState.EMPTY and self.state is not WriterState.NO_FILE:
|
||||
raise ValueError(f'Expected output file to be empty or absent, got {self.state}')
|
||||
if self.state is not WriterState.NO_FILE:
|
||||
raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
|
||||
|
||||
if name in self.tensors:
|
||||
raise ValueError(f'Duplicated tensor name {name!r}')
|
||||
|
@ -263,8 +276,8 @@ class GGUFWriter:
|
|||
fp.write(bytes([0] * pad))
|
||||
|
||||
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
|
||||
if self.state is not WriterState.TI_DATA:
|
||||
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
|
||||
if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
|
||||
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
|
||||
assert self.fout is not None
|
||||
|
||||
if self.endianess == GGUFEndian.BIG:
|
||||
|
@ -273,6 +286,8 @@ class GGUFWriter:
|
|||
tensor.tofile(self.fout)
|
||||
self.write_padding(self.fout, tensor.nbytes)
|
||||
|
||||
self.state = WriterState.WEIGHTS
|
||||
|
||||
def write_tensors_to_file(self, *, progress: bool = False) -> None:
|
||||
self.write_ti_data_to_file()
|
||||
|
||||
|
@ -299,14 +314,14 @@ class GGUFWriter:
|
|||
bar.update(ti.nbytes)
|
||||
self.write_padding(self.fout, ti.nbytes)
|
||||
ti.tensor = None
|
||||
else:
|
||||
self.temp_file.seek(0)
|
||||
|
||||
return
|
||||
shutil.copyfileobj(self.temp_file, self.fout)
|
||||
self.flush()
|
||||
self.temp_file.close()
|
||||
|
||||
self.temp_file.seek(0)
|
||||
|
||||
shutil.copyfileobj(self.temp_file, self.fout)
|
||||
self.flush()
|
||||
self.temp_file.close()
|
||||
self.state = WriterState.WEIGHTS
|
||||
|
||||
def flush(self) -> None:
|
||||
assert self.fout is not None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue