use simplification from #7827

This commit is contained in:
Christian Zhou-Zheng 2024-06-08 23:14:26 -04:00
parent 666bb097a2
commit 03cc9bcbe8

View file

@ -61,7 +61,7 @@ class GGUFWriterSplit(GGUFWriter):
kv_data: KVTempData
split_arguments: SplitArguments
shards: list[Shard]
shard_writers: list[GGUFWriter]
shard_writers: list[tuple[GGUFWriter, os.PathLike[str]]]
def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: SplitArguments,
use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE
@ -115,17 +115,15 @@ class GGUFWriterSplit(GGUFWriter):
logger.info("Dry run, not writing files")
exit()
# we don't want to initialize GGUFWriters until now because they create files
for i, shard in enumerate(self.shards):
# add_architecture is used for consistency - examples/gguf_split doesn't add arch to all shards
writer = GGUFWriter(shard.path, self.arch, use_temp_file=self.use_temp_file,
writer = GGUFWriter(None, self.arch, use_temp_file=self.use_temp_file,
endianess=self.endianess, add_architecture=(i == 0))
# only the first shard needs all the KV data
if i == 0:
for key, (value, etype) in self.kv_data.items():
writer.add_key(key)
writer.add_val(value, etype)
writer.add_key_value(key, value, etype)
# add split metadata unless it's one file - small first shard splits even with SplitStyle.NONE
if self.split_arguments.split_style != SplitStyle.NONE or self.split_arguments.small_first_shard:
@ -141,14 +139,14 @@ class GGUFWriterSplit(GGUFWriter):
except IndexError:
break
self.shard_writers.append(writer)
self.shard_writers.append((writer, shard.path))
def write_header_to_file(self) -> None:
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
if self.state is not WriterState.EMPTY:
raise ValueError(f'Expected GGUFWriterSplit state to be EMPTY, got {self.state}')
for writer in self.shard_writers:
writer.write_header_to_file()
for (writer, path) in self.shard_writers:
writer.write_header_to_file(path)
self.state = WriterState.HEADER
@ -156,7 +154,7 @@ class GGUFWriterSplit(GGUFWriter):
if self.state is not WriterState.HEADER:
raise ValueError(f'Expected GGUFWriterSplit state to be HEADER, got {self.state}')
for writer in self.shard_writers:
for (writer, _) in self.shard_writers:
writer.write_kv_data_to_file()
self.state = WriterState.KV_DATA
@ -167,32 +165,21 @@ class GGUFWriterSplit(GGUFWriter):
running_total = self.total_tensors
for i in range(len(self.shard_writers)):
writer = self.shard_writers[i]
is_metadata = writer.ti_data_count == 0
writer = self.shard_writers[i][0]
is_metadata = len(writer.tensors) == 0
if is_metadata:
logger.info(f"Writing to shard {i + 1}/{len(self.shards)} with metadata only")
else:
logger.info(f"Writing to shard {i + 1}/{len(self.shards)} with {writer.ti_data_count}/{running_total} remaining tensors (of {self.total_tensors} total)")
running_total -= writer.ti_data_count
logger.info(f"Writing to shard {i + 1}/{len(self.shards)} with {len(writer.tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)")
running_total -= len(writer.tensors)
writer.write_tensors_to_file(progress=(progress and not is_metadata))
del writer
self.state = WriterState.TI_DATA
# override add_key, add_val to handle kv data separately
def add_key(self, key: str) -> None:
self.recent_key = key
def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None:
if self.recent_key is None:
raise ValueError("No key set for value")
self.kv_data[self.recent_key] = (val, vtype)
# need to handle arrays separately
def add_array(self, key: str, val: Sequence[Any]) -> None:
if not isinstance(val, Sequence):
raise ValueError(f'Expected a sequence for {key}, got {type(val)}')
self.kv_data[key] = (val, GGUFValueType.ARRAY)
# override add_key_value to handle kv data separately
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
self.kv_data[key] = (val, vtype)
def add_tensor(
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
@ -218,7 +205,7 @@ class GGUFWriterSplit(GGUFWriter):
self.shards[-1].tensors.append((name, tensor, raw_dtype))
def close(self) -> None:
for writer in self.shard_writers:
for (writer, _) in self.shard_writers:
writer.close()
@staticmethod