From 03cc9bcbe80d957e93a70d4a62e1f878d024b8c8 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Sat, 8 Jun 2024 23:14:26 -0400 Subject: [PATCH] use simplification from #7827 --- gguf-py/gguf/gguf_writer_split.py | 45 +++++++++++-------------------- 1 file changed, 16 insertions(+), 29 deletions(-) diff --git a/gguf-py/gguf/gguf_writer_split.py b/gguf-py/gguf/gguf_writer_split.py index b4836737a..bc1e9443a 100644 --- a/gguf-py/gguf/gguf_writer_split.py +++ b/gguf-py/gguf/gguf_writer_split.py @@ -61,7 +61,7 @@ class GGUFWriterSplit(GGUFWriter): kv_data: KVTempData split_arguments: SplitArguments shards: list[Shard] - shard_writers: list[GGUFWriter] + shard_writers: list[tuple[GGUFWriter, os.PathLike[str]]] def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: SplitArguments, use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE @@ -115,17 +115,15 @@ class GGUFWriterSplit(GGUFWriter): logger.info("Dry run, not writing files") exit() - # we don't want to initialize GGUFWriters until now because they create files for i, shard in enumerate(self.shards): # add_architecture is used for consistency - examples/gguf_split doesn't add arch to all shards - writer = GGUFWriter(shard.path, self.arch, use_temp_file=self.use_temp_file, + writer = GGUFWriter(None, self.arch, use_temp_file=self.use_temp_file, endianess=self.endianess, add_architecture=(i == 0)) # only the first shard needs all the KV data if i == 0: for key, (value, etype) in self.kv_data.items(): - writer.add_key(key) - writer.add_val(value, etype) + writer.add_key_value(key, value, etype) # add split metadata unless it's one file - small first shard splits even with SplitStyle.NONE if self.split_arguments.split_style != SplitStyle.NONE or self.split_arguments.small_first_shard: @@ -141,14 +139,14 @@ class GGUFWriterSplit(GGUFWriter): except IndexError: break - self.shard_writers.append(writer) + self.shard_writers.append((writer, shard.path)) - def write_header_to_file(self) -> None: + def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None: if self.state is not WriterState.EMPTY: raise ValueError(f'Expected GGUFWriterSplit state to be EMPTY, got {self.state}') - for writer in self.shard_writers: - writer.write_header_to_file() + for (writer, path) in self.shard_writers: + writer.write_header_to_file(path) self.state = WriterState.HEADER @@ -156,7 +154,7 @@ class GGUFWriterSplit(GGUFWriter): if self.state is not WriterState.HEADER: raise ValueError(f'Expected GGUFWriterSplit state to be HEADER, got {self.state}') - for writer in self.shard_writers: + for (writer, _) in self.shard_writers: writer.write_kv_data_to_file() self.state = WriterState.KV_DATA @@ -167,32 +165,21 @@ class GGUFWriterSplit(GGUFWriter): running_total = self.total_tensors for i in range(len(self.shard_writers)): - writer = self.shard_writers[i] - is_metadata = writer.ti_data_count == 0 + writer = self.shard_writers[i][0] + is_metadata = len(writer.tensors) == 0 if is_metadata: logger.info(f"Writing to shard {i + 1}/{len(self.shards)} with metadata only") else: - logger.info(f"Writing to shard {i + 1}/{len(self.shards)} with {writer.ti_data_count}/{running_total} remaining tensors (of {self.total_tensors} total)") - running_total -= writer.ti_data_count + logger.info(f"Writing to shard {i + 1}/{len(self.shards)} with {len(writer.tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)") + running_total -= len(writer.tensors) writer.write_tensors_to_file(progress=(progress and not is_metadata)) del writer self.state = WriterState.TI_DATA - # override add_key, add_val to handle kv data separately - def add_key(self, key: str) -> None: - self.recent_key = key - - def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None: - if self.recent_key is None: - raise ValueError("No key set for value") - self.kv_data[self.recent_key] = (val, vtype) - - # need to handle arrays separately - def add_array(self, key: str, val: Sequence[Any]) -> None: - if not isinstance(val, Sequence): - raise ValueError(f'Expected a sequence for {key}, got {type(val)}') - self.kv_data[key] = (val, GGUFValueType.ARRAY) + # override add_key_value to handle kv data separately + def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None: + self.kv_data[key] = (val, vtype) def add_tensor( self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, @@ -218,7 +205,7 @@ class GGUFWriterSplit(GGUFWriter): self.shards[-1].tensors.append((name, tensor, raw_dtype)) def close(self) -> None: - for writer in self.shard_writers: + for (writer, _) in self.shard_writers: writer.close() @staticmethod