diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 766ae86b4..0ef75ff07 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -56,14 +56,6 @@ class GGUFValue: type: GGUFValueType -@dataclass -class Shard: - path: Path - tensor_count: int - size: int - tensors: deque[TensorTempData] - - class SplitArguments: def __init__(self, args: Namespace) -> None: self.split_max_tensors = args.split_max_tensors if args.split_max_tensors else 0 @@ -91,10 +83,10 @@ class SplitStyle(Enum): class GGUFWriter: - fout: list[BufferedWriter | None] | None + fout: list[BufferedWriter] | None path: os.PathLike[str] | str | None temp_file: tempfile.SpooledTemporaryFile[bytes] | None - tensors: list[dict[str, TensorInfo | np.ndarray[Any, Any]]] + tensors: list[dict[str, TensorInfo]] kv_data: list[dict[str, GGUFValue]] state: WriterState _simple_value_packing = { @@ -137,7 +129,7 @@ class GGUFWriter: def verify_arguments(self) -> None: total_tensors = sum(len(ti) for ti in self.tensors) - total_size = sum(sum(GGUFWriter.get_tensor_size(ti) for ti in t.values()) for t in self.tensors) + total_size = sum(ti.nbytes for t in self.tensors for ti in t.values()) if self.split_arguments.split_max_tensors and total_tensors < self.split_arguments.split_max_tensors: logger.warning("Model has fewer tensors than the split threshold, not splitting") @@ -149,10 +141,10 @@ class GGUFWriter: # no shards are created when writing vocab so make one if not self.tensors or len(self.tensors) == 0: - self.tensors.append(dict()) + self.tensors = [dict()] - def format_shard_names(self) -> list[os.PathLike[str]]: - pathobj = Path(self.path) + def format_shard_names(self, path: os.PathLike[str] | str | None = None) -> list[os.PathLike[str]]: + pathobj = Path(path) if self.split_arguments.split_style == SplitStyle.NONE: return [pathobj] @@ -174,14 +166,15 @@ class GGUFWriter: if self.path is not None: self.fout = [] - for fout in self.format_shard_names(): + for fout in self.format_shard_names(self.path): self.fout.append(open(fout, "wb")) self.state = WriterState.EMPTY - def print_plan(self) -> None: + def print_plan(self, path: os.PathLike[str] | str | None = None) -> None: logger.info("Writing the following files:") - for i in range(len(self.fout)): - logger.info(f"{self.fout[i].name}: n_tensors = {len(self.tensors[i])}, total_size = {GGUFWriter.format_n_bytes_to_str(GGUFWriter.get_tensors_total_size(self.tensors[i].values()))}") + filenames = self.format_shard_names(path) + for i in range(len(filenames)): + logger.info(f"{filenames[i]}: n_tensors = {len(self.tensors[i])}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in self.tensors[i].values()))}") if self.split_arguments.dry_run: logger.info("Dry run, not writing files") @@ -204,8 +197,8 @@ class GGUFWriter: def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None: self.verify_arguments() + self.print_plan(path) self.open_output_file(path) - self.print_plan() if self.state is not WriterState.EMPTY: raise ValueError(f'Expected output file to be empty, got {self.state}') @@ -215,13 +208,12 @@ class GGUFWriter: self.add_shard_kv_data() - for i in range(len(self.fout)): - fout = self.fout[i] + for fout, tensors, kv_data in zip(self.fout, self.tensors, self.kv_data): self._write_packed(fout, " None: @@ -246,12 +238,12 @@ class GGUFWriter: raise ValueError(f'Expected output file to contain KV data, got {self.state}') assert self.fout is not None - for i in range(len(self.fout)): - assert self.fout[i] is not None + for fout, tensors in zip(self.fout, self.tensors): + assert fout is not None ti_data = bytearray() offset_tensor = 0 - for name, ti in self.tensors[i].items(): + for name, ti in tensors.items(): ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False) n_dims = len(ti.shape) ti_data += self._pack("I", n_dims) @@ -261,8 +253,8 @@ class GGUFWriter: ti_data += self._pack("Q", offset_tensor) offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment) - self.fout[i].write(ti_data) - self.fout[i].flush() + fout.write(ti_data) + fout.flush() self.state = WriterState.TI_DATA def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None: @@ -359,7 +351,7 @@ class GGUFWriter: and len(self.tensors[-1]) >= self.split_arguments.split_max_tensors) \ # or split when over size limit or (self.split_arguments.split_style == SplitStyle.SIZE \ - and GGUFWriter.get_tensors_total_size(self.tensors[-1].values()) + tensor_nbytes > self.split_arguments.split_max_size)): + and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_arguments.split_max_size)): self.tensors.append(dict()) @@ -424,7 +416,7 @@ class GGUFWriter: if progress: from tqdm import tqdm - total_bytes = GGUFWriter.get_tensors_total_size(self.tensors[i].values()) + total_bytes = sum(ti.nbytes for ti in self.tensors[i].values()) bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) @@ -739,17 +731,6 @@ class GGUFWriter: assert fout is not None fout.write(self._pack(fmt, value, skip_pack_prefix)) - @staticmethod - def get_tensor_size(tensor) -> int: - try: - return tensor.data_type.elements_to_bytes(np.prod(tensor.shape)) - except AttributeError: # numpy ndarray[Any, Any] - return tensor.nbytes - - @staticmethod - def get_tensors_total_size(tensors) -> int: - return sum(GGUFWriter.get_tensor_size(ti) for ti in tensors) - @staticmethod def split_str_to_n_bytes(split_str: str) -> int: if split_str.endswith("K"):