refactor SplitStrategy to be a deque

Instead of having SplitStrategy hold a `data` field that is a deque, make SplitStrategy a subclass of deque itself.
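For context, here is a minimal sketch of the pattern this commit applies, using hypothetical names (ShardListWrapper, ShardDeque), not the actual gguf-py classes: instead of wrapping a deque in a `data` attribute and forwarding every operation to it, the class inherits from deque and gets append, popleft, len, indexing and iteration for free.

    # a sketch of the refactor pattern, not the real gguf-py code
    from collections import deque

    # Before: a thin wrapper that forwards everything to an internal deque.
    class ShardListWrapper:
        def __init__(self):
            self.data = deque()

        def __len__(self):
            return len(self.data)

        def append(self, value):
            self.data.append(value)

        def popleft(self):
            return self.data.popleft()

    # After: subclass deque directly; the forwarding methods disappear.
    class ShardDeque(deque):
        def __init__(self):
            super().__init__()

    shards = ShardDeque()
    shards.append(("shard-00001.gguf", [], None))
    first = shards[0]             # indexing works without a .data indirection
    while shards:
        shard = shards.popleft()  # deque methods are available on the object itself

Callers then drop the `.data` indirection, which is exactly what the hunks below do (`self.split_strategy[0][2]`, `self.split_strategy.popleft()`, iterating `self.split_strategy` directly).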
Christian Zhou-Zheng 2024-05-24 00:28:48 -04:00
parent 3ff27efa89
commit 6b5c3753c8


@@ -55,31 +55,23 @@ class SplitArguments:
     split_style: SplitStyle
 
     def __init__(self, args: Namespace = None) -> None:
-        if args is None:
-            self.split = False
-            self.dry_run = False
-            self.small_first_shard = False
-            self.split_max_tensors = 0
-            self.split_max_size = 0
-            self.split_style = SplitStyle.NONE
-        else:
-            self.split = args.split
-            self.split_max_tensors = args.split_max_tensors
-            self.split_max_size = SplitStrategy.split_str_to_n_bytes(args.split_max_size) if args.split_max_size else None
-            self.dry_run = args.dry_run
-            self.small_first_shard = not args.large_first_shard
-            self.split_style = SplitStyle.NONE if not self.split \
-                else SplitStyle.TENSORS if self.split_max_tensors \
-                else SplitStyle.SIZE
+        self.split = args.split if args else False
+        self.split_max_tensors = args.split_max_tensors if args else 0
+        self.split_max_size = SplitStrategy.split_str_to_n_bytes(args.split_max_size) if args and args.split_max_size else 0
+        self.dry_run = args.dry_run if args else False
+        self.small_first_shard = not args.large_first_shard if args else False
+        self.split_style = SplitStyle.NONE if not self.split or not args \
+            else SplitStyle.TENSORS if self.split_max_tensors \
+            else SplitStyle.SIZE
 
 
-class SplitStrategy:
+class SplitStrategy(deque):
     data: SplitTensorsPerFile
 
     def __init__(self, fname_out: os.PathLike[str], model: list[TensorTempData], arch: str,
                  split_arguments: SplitArguments, use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE,
                  ):
-        self.data = deque()
+        super().__init__()
 
         if split_arguments.split_style == SplitStyle.NONE:
             self.append((fname_out, model, GGUFWriter(fname_out, arch, use_temp_file=use_temp_file, endianess=endianess)))
@@ -121,15 +113,6 @@ class SplitStrategy:
                 outname = fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + shard_offset, total_shards))
                 self.append((outname, deque(shard), GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess)))
 
-    def __len__(self):
-        return len(self.data)
-
-    def append(self, value: TensorTempData):
-        self.data.append(value)
-
-    def remove(self, item: TensorTempData):
-        self.data.remove(item)
-
     @staticmethod
     def get_tensor_size(tensor) -> int:
         # we don't have the LazyTensor class here from convert.py but we can try
@@ -216,18 +199,18 @@ class GGUFManager:
         # only the first shard needs all the KV data
         for key, (value, etype) in self.kv_data.items():
-            self.split_strategy.data[0][2].add_key(key)
-            self.split_strategy.data[0][2].add_val(value, etype)
+            self.split_strategy[0][2].add_key(key)
+            self.split_strategy[0][2].add_val(value, etype)
 
         if self.split_arguments.split_style != SplitStyle.NONE:
-            for i, (_, _, writer) in enumerate(self.split_strategy.data):
+            for i, (_, _, writer) in enumerate(self.split_strategy):
                 writer.add_uint16(LLM_KV_SPLIT_NO, i)
                 writer.add_uint16(LLM_KV_SPLIT_COUNT, self.total_shards)
                 writer.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, self.total_tensors)
 
         # metadata/vocab only can write and return here
         if meta_only:
-            for i, (_, _, writer) in enumerate(self.split_strategy.data):
+            for i, (_, _, writer) in enumerate(self.split_strategy):
                 writer.write_header_to_file()
                 writer.write_kv_data_to_file()
             return
@@ -235,14 +218,14 @@ class GGUFManager:
         # tensor writing code starts here
 
         print("\nWriting the following files:")
-        for (shard_path, shard_tensors, _) in self.split_strategy.data:
+        for (shard_path, shard_tensors, _) in self.split_strategy:
             size = SplitStrategy.format_n_bytes_to_str(sum(SplitStrategy.get_tensor_size(t[1]) for t in shard_tensors)) if shard_tensors else "negligible - metadata only"
             print(f" {shard_path}: n_tensors = {len(shard_tensors) if shard_tensors else 0}, total_size = {size}")
 
         if self.split_arguments.dry_run:
             print("\nDry run, not writing files")
             # instantiating GGUFWriters creates files
-            for name, _, _ in self.split_strategy.data:
+            for name, _, _ in self.split_strategy:
                 os.remove(name)
             return
@@ -251,7 +234,7 @@ class GGUFManager:
         ct = 0
         while True:
             try:
-                (_, tensors, writer) = self.split_strategy.data.popleft()
+                (_, tensors, writer) = self.split_strategy.popleft()
             except IndexError:
                 break
@@ -340,7 +323,7 @@ class GGUFManager:
        #self.write_padding(self.temp_file, tensor.nbytes)
 
     def close(self) -> None:
-        for _, _, writer in self.split_strategy.data:
+        for _, _, writer in self.split_strategy:
             writer.close()
 
     def add_architecture(self) -> None: