use simplification from #7827
This commit is contained in:
parent
666bb097a2
commit
03cc9bcbe8
1 changed file with 16 additions and 29 deletions
|
@ -61,7 +61,7 @@ class GGUFWriterSplit(GGUFWriter):
|
|||
kv_data: KVTempData
|
||||
split_arguments: SplitArguments
|
||||
shards: list[Shard]
|
||||
shard_writers: list[GGUFWriter]
|
||||
shard_writers: list[tuple[GGUFWriter, os.PathLike[str]]]
|
||||
|
||||
def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: SplitArguments,
|
||||
use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE
|
||||
|
@ -115,17 +115,15 @@ class GGUFWriterSplit(GGUFWriter):
|
|||
logger.info("Dry run, not writing files")
|
||||
exit()
|
||||
|
||||
# we don't want to initialize GGUFWriters until now because they create files
|
||||
for i, shard in enumerate(self.shards):
|
||||
# add_architecture is used for consistency - examples/gguf_split doesn't add arch to all shards
|
||||
writer = GGUFWriter(shard.path, self.arch, use_temp_file=self.use_temp_file,
|
||||
writer = GGUFWriter(None, self.arch, use_temp_file=self.use_temp_file,
|
||||
endianess=self.endianess, add_architecture=(i == 0))
|
||||
|
||||
# only the first shard needs all the KV data
|
||||
if i == 0:
|
||||
for key, (value, etype) in self.kv_data.items():
|
||||
writer.add_key(key)
|
||||
writer.add_val(value, etype)
|
||||
writer.add_key_value(key, value, etype)
|
||||
|
||||
# add split metadata unless it's one file - small first shard splits even with SplitStyle.NONE
|
||||
if self.split_arguments.split_style != SplitStyle.NONE or self.split_arguments.small_first_shard:
|
||||
|
@ -141,14 +139,14 @@ class GGUFWriterSplit(GGUFWriter):
|
|||
except IndexError:
|
||||
break
|
||||
|
||||
self.shard_writers.append(writer)
|
||||
self.shard_writers.append((writer, shard.path))
|
||||
|
||||
def write_header_to_file(self) -> None:
|
||||
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
|
||||
if self.state is not WriterState.EMPTY:
|
||||
raise ValueError(f'Expected GGUFWriterSplit state to be EMPTY, got {self.state}')
|
||||
|
||||
for writer in self.shard_writers:
|
||||
writer.write_header_to_file()
|
||||
for (writer, path) in self.shard_writers:
|
||||
writer.write_header_to_file(path)
|
||||
|
||||
self.state = WriterState.HEADER
|
||||
|
||||
|
@ -156,7 +154,7 @@ class GGUFWriterSplit(GGUFWriter):
|
|||
if self.state is not WriterState.HEADER:
|
||||
raise ValueError(f'Expected GGUFWriterSplit state to be HEADER, got {self.state}')
|
||||
|
||||
for writer in self.shard_writers:
|
||||
for (writer, _) in self.shard_writers:
|
||||
writer.write_kv_data_to_file()
|
||||
|
||||
self.state = WriterState.KV_DATA
|
||||
|
@ -167,32 +165,21 @@ class GGUFWriterSplit(GGUFWriter):
|
|||
|
||||
running_total = self.total_tensors
|
||||
for i in range(len(self.shard_writers)):
|
||||
writer = self.shard_writers[i]
|
||||
is_metadata = writer.ti_data_count == 0
|
||||
writer = self.shard_writers[i][0]
|
||||
is_metadata = len(writer.tensors) == 0
|
||||
if is_metadata:
|
||||
logger.info(f"Writing to shard {i + 1}/{len(self.shards)} with metadata only")
|
||||
else:
|
||||
logger.info(f"Writing to shard {i + 1}/{len(self.shards)} with {writer.ti_data_count}/{running_total} remaining tensors (of {self.total_tensors} total)")
|
||||
running_total -= writer.ti_data_count
|
||||
logger.info(f"Writing to shard {i + 1}/{len(self.shards)} with {len(writer.tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)")
|
||||
running_total -= len(writer.tensors)
|
||||
writer.write_tensors_to_file(progress=(progress and not is_metadata))
|
||||
del writer
|
||||
|
||||
self.state = WriterState.TI_DATA
|
||||
|
||||
# override add_key, add_val to handle kv data separately
|
||||
def add_key(self, key: str) -> None:
|
||||
self.recent_key = key
|
||||
|
||||
def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None:
|
||||
if self.recent_key is None:
|
||||
raise ValueError("No key set for value")
|
||||
self.kv_data[self.recent_key] = (val, vtype)
|
||||
|
||||
# need to handle arrays separately
|
||||
def add_array(self, key: str, val: Sequence[Any]) -> None:
|
||||
if not isinstance(val, Sequence):
|
||||
raise ValueError(f'Expected a sequence for {key}, got {type(val)}')
|
||||
self.kv_data[key] = (val, GGUFValueType.ARRAY)
|
||||
# override add_key_value to handle kv data separately
|
||||
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
|
||||
self.kv_data[key] = (val, vtype)
|
||||
|
||||
def add_tensor(
|
||||
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
|
||||
|
@ -218,7 +205,7 @@ class GGUFWriterSplit(GGUFWriter):
|
|||
self.shards[-1].tensors.append((name, tensor, raw_dtype))
|
||||
|
||||
def close(self) -> None:
|
||||
for writer in self.shard_writers:
|
||||
for (writer, _) in self.shard_writers:
|
||||
writer.close()
|
||||
|
||||
@staticmethod
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue